Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add model aliases #122

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ Merci, avant chaque pull request, de vérifier le bon déploiement de votre API
1. Lancez l'API en local avec la commande suivante:

```bash
DEFAULT_RATE_LIMIT="30/minute" uvicorn app.main:app --port 8080 --log-level debug --reload
uvicorn app.main:app --port 8080 --log-level debug --reload
```

2. Exécutez les tests unitaires à la racine du projet

```bash
DEFAULT_RATE_LIMIT="30/minute" PYTHONPATH=. pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO
PYTHONPATH=. pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO
```

# Notebooks
Expand Down
51 changes: 46 additions & 5 deletions app/clients/_modelclients.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import requests

from app.schemas.embeddings import Embeddings
from app.schemas.chat import ChatCompletion
from app.schemas.models import Model, Models
from app.schemas.rerank import Rerank
from app.schemas.settings import Settings
Expand Down Expand Up @@ -69,15 +70,36 @@ def get_models_list(self, *args, **kwargs) -> Models:
owned_by=self.owned_by,
created=self.created,
max_context_length=self.max_context_length,
aliases=self.aliases,
type=self.type,
status=self.status,
)

return Models(data=[data])


def create_chat_completions(self, *args, **kwargs):
    """
    Custom method to overwrite OpenAI's create method to raise HTTPException from model API.

    Forwards the request body (``kwargs``) to the model API's ``chat/completions``
    endpoint and returns the parsed response as a ``ChatCompletion``.

    Raises:
        HTTPException: with the upstream status code when the model API returns an
            HTTP error, or 502 when the model API cannot be reached at all.
    """
    url = f"{self.base_url}chat/completions"
    headers = {"Authorization": f"Bearer {self.api_key}"}
    try:
        response = requests.post(url=url, headers=headers, json=kwargs)
        response.raise_for_status()
        data = response.json()

        return ChatCompletion(**data)

    except requests.exceptions.HTTPError as e:
        # Propagate the upstream error with its original status code. The error body
        # is expected to be JSON with a "message" key, but fall back to the raw text
        # so a non-JSON upstream error does not crash this handler.
        try:
            detail = json.loads(e.response.text)["message"]
        except (ValueError, KeyError, TypeError):
            detail = e.response.text
        raise HTTPException(status_code=e.response.status_code, detail=detail) from e

    except requests.exceptions.RequestException as e:
        # Network-level failure (timeout, connection refused, ...): there is no
        # upstream response to forward, so report the model API as a bad gateway.
        raise HTTPException(status_code=502, detail=str(e)) from e


# TODO: possibly unused now that chat completions has its own override — verify callers before removing.
def create_embeddings(self, *args, **kwargs):
"""
Custom method to overwrite OpenAI's create method to raise HTTPException from model API.
"""
try:
url = f"{self.base_url}embeddings"
headers = {"Authorization": f"Bearer {self.api_key}"}
Expand All @@ -104,13 +126,17 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI
# set attributes for unavailable models
self.id = ""
self.owned_by = ""
self.aliases = []
self.created = round(number=time.time())
self.max_context_length = None

# set real attributes if model is available
self.models.list = partial(get_models_list, self)
response = self.models.list()

if self.type == LANGUAGE_MODEL_TYPE:
self.chat.completions.create = partial(create_chat_completions, self)

if self.type == EMBEDDINGS_MODEL_TYPE:
response = self.embeddings.create(model=self.id, input="hello world")
self.vector_size = len(response.data[0].embedding)
Expand Down Expand Up @@ -145,25 +171,40 @@ class ModelClients(dict):
"""

def __init__(self, settings: Settings) -> None:
for model_config in settings.models:
model = ModelClient(base_url=model_config.url, api_key=model_config.key, type=model_config.type)
self.aliases = {alias: model_id for model_id, aliases in settings.models.aliases.items() for alias in aliases}

for model_settings in settings.clients.models:
model = ModelClient(
base_url=model_settings.url,
api_key=model_settings.key,
type=model_settings.type,
)
if model.status == "unavailable":
logger.error(msg=f"unavailable model API on {model_config.url}, skipping.")
logger.error(msg=f"unavailable model API on {model_settings.url}, skipping.")
continue
try:
logger.info(msg=f"Adding model API {model_config.url} to the client...")
logger.info(msg=f"Adding model API {model_settings.url} to the client...")
self.__setitem__(key=model.id, value=model)
logger.info(msg="done.")
except Exception as e:
logger.error(msg=e)

model.aliases = settings.models.aliases.get(model.id, [])

for alias in self.aliases.keys():
assert alias not in self.keys(), "Alias is already used by another model."

assert settings.internet.default_language_model in self.keys(), "Default internet language model not found."
assert settings.internet.default_embeddings_model in self.keys(), "Default internet embeddings model not found."

def __setitem__(self, key: str, value) -> None:
if any(key == k for k in self.keys()):
if key in self.keys():
raise ValueError(f"duplicated model ID {key}, skipping.")
else:
super().__setitem__(key, value)

def __getitem__(self, key: str) -> Any:
key = self.aliases.get(key, key)
try:
item = super().__getitem__(key)
assert item.status == "available", "Model not available."
Expand Down
5 changes: 2 additions & 3 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import httpx

from app.schemas.audio import AudioTranscription
from app.schemas.settings import AUDIO_MODEL_TYPE
from app.utils.exceptions import ModelNotFoundException
from app.utils.lifespan import clients, limiter
from app.utils.security import User, check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.variables import DEFAULT_TIMEOUT, AUDIO_MODEL_TYPE

router = APIRouter()

Expand Down Expand Up @@ -135,7 +134,7 @@


@router.post("/audio/transcriptions")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
Expand Down
13 changes: 9 additions & 4 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@
from app.schemas.search import Search
from app.schemas.security import User
from app.schemas.settings import Settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE

router = APIRouter()


@router.post(path="/chat/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def chat_completions(
request: Request, body: ChatCompletionRequest, user: User = Security(dependency=check_api_key)
) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create for the API specification.
"""

client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()
body.model = client.id # replace alias by model id
url = f"{client.base_url}chat/completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

Expand All @@ -42,8 +47,8 @@ def retrieval_augmentation_generation(
internet_manager=InternetManager(
model_clients=clients.models,
internet_client=clients.internet,
default_language_model_id=settings.internet.args.default_language_model,
default_embeddings_model_id=settings.internet.args.default_embeddings_model,
default_language_model_id=settings.internet.default_language_model,
default_embeddings_model_id=settings.internet.default_embeddings_model,
),
)
searches = search_manager.query(
Expand Down
2 changes: 1 addition & 1 deletion app/endpoints/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.get("/chunks/{collection}/{document}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_chunks(
request: Request,
collection: UUID,
Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@router.post("/collections")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response:
"""
Create a new collection.
Expand All @@ -35,7 +35,7 @@ async def create_collection(request: Request, body: CollectionRequest, user: Use


@router.get("/collections")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]:
"""
Get list of collections.
Expand All @@ -54,7 +54,7 @@ async def get_collections(request: Request, user: User = Security(check_api_key)


@router.delete("/collections/{collection}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a collection.
Expand Down
15 changes: 10 additions & 5 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from fastapi import APIRouter, Request, Security, HTTPException
import httpx
import json

from fastapi import APIRouter, HTTPException, Request, Security
import httpx

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE

router = APIRouter()


@router.post(path="/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def completions(request: Request, body: CompletionRequest, user: User = Security(dependency=check_api_key)) -> Completions:
"""
Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create for the API specification.
"""
client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()
body.model = client.id # replace alias by model id
url = f"{client.base_url}completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.get("/documents/{collection}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_documents(
request: Request,
collection: UUID,
Expand All @@ -31,7 +31,7 @@ async def get_documents(


@router.delete("/documents/{collection}/{document}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a document and relative collections.
Expand Down
9 changes: 2 additions & 7 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@router.post(path="/embeddings")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(dependency=check_api_key)) -> Embeddings:
"""
Embedding API similar to OpenAI's API.
Expand All @@ -24,19 +24,14 @@ async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Sec
client = clients.models[body.model]
if client.type != EMBEDDINGS_MODEL_TYPE:
raise WrongModelTypeException()

body.model = client.id # replace alias by model id
url = f"{client.base_url}embeddings"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
# try:
response.raise_for_status()
# except httpx.HTTPStatusError as e:
# if "`inputs` must have less than" in e.response.text:
# raise ContextLengthExceededException()
# raise e
data = response.json()
return Embeddings(**data)
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

@router.get("/models/{model:path}")
@router.get("/models")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]:
"""
Model API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/models/list for the API specification.
"""
if model is not None:
client = clients.models[model]
response = [row for row in client.models.list().data if row.id == model][0]
model = clients.models[model]
response = [row for row in model.models.list().data if row.id == model.id][0]
else:
response = {"object": "list", "data": []}
for model_id, client in clients.models.items():
Expand Down
11 changes: 6 additions & 5 deletions app/endpoints/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@


@router.post("/rerank")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)):
"""
Rerank a list of inputs with a language model or reranker model.
"""
model = clients.models[body.model]

if clients.models[body.model].type == LANGUAGE_MODEL_TYPE:
reranker = LanguageModelReranker(model=clients.models[body.model])
if model.type == LANGUAGE_MODEL_TYPE:
reranker = LanguageModelReranker(model=model)
data = reranker.create(prompt=body.prompt, input=body.input)
elif clients.models[body.model].type == RERANK_MODEL_TYPE:
data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model)
elif model.type == RERANK_MODEL_TYPE:
data = model.rerank.create(prompt=body.prompt, input=body.input, model=model.id)
else:
raise WrongModelTypeException()

Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@router.post(path="/search")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def search(request: Request, body: SearchRequest, user: User = Security(dependency=check_api_key)) -> Searches:
"""
Endpoint to search on the internet or with our search client.
Expand All @@ -26,8 +26,8 @@ async def search(request: Request, body: SearchRequest, user: User = Security(de
internet_manager=InternetManager(
model_clients=clients.models,
internet_client=clients.internet,
default_language_model_id=settings.internet.args.default_language_model,
default_embeddings_model_id=settings.internet.args.default_embeddings_model,
default_language_model_id=settings.internet.default_language_model,
default_embeddings_model_id=settings.internet.default_embeddings_model,
),
)

Expand Down
Loading