(feat) Speedup health endpoint #1023

Merged: 2 commits, Dec 6, 2023
2 changes: 1 addition & 1 deletion litellm/__init__.py
@@ -2,6 +2,7 @@
 import threading, requests
 from typing import Callable, List, Optional, Dict, Union, Any
 from litellm.caching import Cache
+from litellm._logging import set_verbose
 import httpx

 input_callback: List[Union[str, Callable]] = []
@@ -11,7 +12,6 @@
 _async_success_callback: List[Callable] = []  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
-set_verbose = False
 email: Optional[
     str
 ] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
5 changes: 5 additions & 0 deletions litellm/_logging.py
@@ -0,0 +1,5 @@
set_verbose = False

def print_verbose(print_statement):
    if set_verbose:
        print(print_statement)
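
A note on the flag's semantics (an editor's sketch, not part of the PR diff): because litellm/__init__.py imports the value via "from litellm._logging import set_verbose", rebinding litellm.set_verbose later does not change the name that print_verbose reads. To toggle verbosity at runtime, rebind the flag on the _logging module itself:

import litellm._logging as _logging

# print_verbose checks the module-level flag in litellm._logging,
# so set it there rather than on the litellm package.
_logging.set_verbose = True
_logging.print_verbose("this line is printed")

_logging.set_verbose = False
_logging.print_verbose("this line is suppressed")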
115 changes: 115 additions & 0 deletions litellm/health_check.py
@@ -0,0 +1,115 @@
import asyncio
import random
from typing import Optional

import litellm
import logging
from litellm._logging import print_verbose


logger = logging.getLogger(__name__)

Review comment (Contributor): what is the purpose of this logger

ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key",
]


def _get_random_llm_message():
    """
    Return a random user message to send to the LLM as the probe request.
    """
    messages = [
        "Hey how's it going?",
        "What's 1 + 1?"
    ]

    return [
        {"role": "user", "content": random.choice(messages)}
    ]


def _clean_litellm_params(litellm_params: dict):
    """
    Clean the litellm params for display to users.
    """
    return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}


async def _perform_health_check(model_list: list):
    """
    Perform a health check for each model in the list.
    """

    async def _check_model(model_params: dict):
        try:
            await litellm.acompletion(**model_params)
        except Exception as e:
            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
            return False

        return True

    prepped_params = []

    for model in model_list:
        litellm_params = model["litellm_params"]
        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
        litellm_params["messages"] = _get_random_llm_message()

        prepped_params.append(litellm_params)

    # Fire all probes at once; total latency is roughly the slowest
    # endpoint rather than the sum over all endpoints.
    tasks = [_check_model(x) for x in prepped_params]

    results = await asyncio.gather(*tasks)

    healthy_endpoints = []
    unhealthy_endpoints = []

    for is_healthy, model in zip(results, model_list):
        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

        if is_healthy:
            healthy_endpoints.append(cleaned_litellm_params)
        else:
            unhealthy_endpoints.append(cleaned_litellm_params)

    return healthy_endpoints, unhealthy_endpoints




async def perform_health_check(model_list: list, model: Optional[str] = None):
    """
    Perform a health check across the given models.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): two lists of cleaned
        litellm params, one for endpoints that answered the probe and
        one for endpoints that raised an error.
    """
    if not model_list:
        return [], []

    if model is not None:
        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]

    models_to_check = []

    for model in model_list:
        litellm_params = model["litellm_params"]
        model_name = litellm.utils.remove_model_id(litellm_params["model"])

        if model_name in litellm.all_embedding_models:
            continue  # Skip embedding models

        models_to_check.append(model)

    # Check only the filtered list; passing model_list here would also
    # probe the embedding models that were just skipped.
    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(models_to_check)

    return healthy_endpoints, unhealthy_endpoints
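
The speedup in this PR comes from the pattern above: instead of calling litellm.completion once per endpoint in a serial loop (as the old /health handler below does), all probes run concurrently under asyncio.gather. A minimal, self-contained sketch of the idea (the sleep-based probe and model names are illustrative stand-ins for litellm.acompletion and real deployments):

import asyncio
import random

async def probe(model: str) -> bool:
    # Stand-in for an acompletion call; real latency is network-bound.
    await asyncio.sleep(random.uniform(0.1, 0.5))
    return True

async def main():
    models = ["model-a", "model-b", "model-c"]
    # Serial: total time is about the sum of probe latencies.
    # Concurrent (below): total time is about the slowest single probe.
    results = await asyncio.gather(*(probe(m) for m in models))
    print(dict(zip(models, results)))

asyncio.run(main())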


34 changes: 8 additions & 26 deletions litellm/proxy/proxy_server.py
@@ -7,6 +7,8 @@
 import hashlib, uuid
 import warnings
 import importlib
+
+from litellm.health_check import perform_health_check
 messages: list = []
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -1131,34 +1133,14 @@ async def test_endpoint(request: Request):
@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"])
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
global llm_model_list
healthy_endpoints = []
unhealthy_endpoints = []
if llm_model_list:
for model_name in llm_model_list:
try:
if model is None or model == model_name["litellm_params"]["model"]: # if model specified, just call that one.
litellm_params = model_name["litellm_params"]
model_name = litellm.utils.remove_model_id(litellm_params["model"]) # removes, ids set by litellm.router
if model_name not in litellm.all_embedding_models: # filter out embedding models
litellm_params["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
litellm_params["model"] = model_name
litellm.completion(**litellm_params)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
healthy_endpoints.append(cleaned_params)
except Exception as e:
print("Got Exception", e)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
unhealthy_endpoints.append(cleaned_params)
pass

healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)

return {
"healthy_endpoints": healthy_endpoints,
"unhealthy_endpoints": unhealthy_endpoints
"unhealthy_endpoints": unhealthy_endpoints,
"healthy_count": len(healthy_endpoints),
"unhealthy_count": len(unhealthy_endpoints),
}

@router.get("/")
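
With the counts added to the response, scripted monitoring gets simpler. A minimal usage sketch (assumes a proxy listening on localhost:8000; the model name is illustrative):

import httpx

resp = httpx.get(
    "http://localhost:8000/health",
    params={"model": "gpt-3.5-turbo"},  # optional; omit to check every configured endpoint
)
data = resp.json()
print(f"healthy: {data['healthy_count']}, unhealthy: {data['unhealthy_count']}")
for endpoint in data["unhealthy_endpoints"]:
    print("unhealthy endpoint:", endpoint)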