Skip to content

Commit

Permalink
Merge pull request #1023 from PSU3D0/speedup_health_endpoint
Browse files Browse the repository at this point in the history
(feat) Speedup health endpoint
  • Loading branch information
ishaan-jaff authored Dec 6, 2023
2 parents bd05797 + 95e5331 commit a4cf4e7
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 27 deletions.
2 changes: 1 addition & 1 deletion litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading, requests
from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache
from litellm._logging import set_verbose
import httpx

input_callback: List[Union[str, Callable]] = []
Expand All @@ -11,7 +12,6 @@
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
set_verbose = False
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
Expand Down
5 changes: 5 additions & 0 deletions litellm/_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Global verbosity flag; other modules (e.g. litellm/__init__.py) re-export
# and may flip this at runtime.
set_verbose = False


def print_verbose(print_statement):
    """Echo ``print_statement`` to stdout, but only when ``set_verbose`` is truthy."""
    if not set_verbose:
        return
    print(print_statement)
115 changes: 115 additions & 0 deletions litellm/health_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import asyncio
import random
from typing import Optional

import litellm
import logging
from litellm._logging import print_verbose


logger = logging.getLogger(__name__)


# Keys stripped from litellm_params before they are echoed back to callers:
# "api_key" would leak a secret, "messages" is test-payload noise.
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key"
]


def _get_random_llm_message():
"""
Get a random message from the LLM.
"""
messages = [
"Hey how's it going?",
"What's 1 + 1?"
]


return [
{"role": "user", "content": random.choice(messages)}
]


def _clean_litellm_params(litellm_params: dict):
    """Return a copy of ``litellm_params`` that is safe to display to users.

    Every key listed in ``ILLEGAL_DISPLAY_PARAMS`` (secrets / bulky test
    payloads) is dropped; the input dict itself is left untouched.
    """
    cleaned = {}
    for key, value in litellm_params.items():
        if key in ILLEGAL_DISPLAY_PARAMS:
            continue
        cleaned[key] = value
    return cleaned


async def _perform_health_check(model_list: list):
"""
Perform a health check for each model in the list.
"""

async def _check_model(model_params: dict):
try:
await litellm.acompletion(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
return False

return True

prepped_params = []

for model in model_list:
litellm_params = model["litellm_params"]
litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
litellm_params["messages"] = _get_random_llm_message()

prepped_params.append(litellm_params)


tasks = [_check_model(x) for x in prepped_params]

results = await asyncio.gather(*tasks)

healthy_endpoints = []
unhealthy_endpoints = []

for is_healthy, model in zip(results, model_list):
cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

if is_healthy:
healthy_endpoints.append(cleaned_litellm_params)
else:
unhealthy_endpoints.append(cleaned_litellm_params)

return healthy_endpoints, unhealthy_endpoints




async def perform_health_check(model_list: list, model: Optional[str] = None):
    """
    Run a health check across the proxy's configured models.

    Args:
        model_list: router-style model entries (each with "litellm_params").
        model: optional model name; when given, only matching entries
            are checked.

    Returns:
        tuple[list, list]: ``(healthy_endpoints, unhealthy_endpoints)`` —
        display-safe litellm_params for the models that passed / failed
        a test completion.
    """
    if not model_list:
        return [], []

    if model is not None:
        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]

    models_to_check = []
    for entry in model_list:  # don't shadow the `model` parameter
        model_name = litellm.utils.remove_model_id(entry["litellm_params"]["model"])
        if model_name in litellm.all_embedding_models:
            continue  # embedding models can't serve chat completions

        models_to_check.append(entry)

    # BUG FIX: previously the unfiltered model_list was passed here, so the
    # embedding-model filter above was dead code and embedding endpoints
    # were always reported unhealthy.
    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(models_to_check)

    return healthy_endpoints, unhealthy_endpoints


34 changes: 8 additions & 26 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import hashlib, uuid
import warnings
import importlib

from litellm.health_check import perform_health_check
messages: list = []
sys.path.insert(
0, os.path.abspath("../..")
Expand Down Expand Up @@ -1173,34 +1175,14 @@ async def test_endpoint(request: Request):
@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"])
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
global llm_model_list
healthy_endpoints = []
unhealthy_endpoints = []
if llm_model_list:
for model_name in llm_model_list:
try:
if model is None or model == model_name["litellm_params"]["model"]: # if model specified, just call that one.
litellm_params = model_name["litellm_params"]
model_name = litellm.utils.remove_model_id(litellm_params["model"]) # removes, ids set by litellm.router
if model_name not in litellm.all_embedding_models: # filter out embedding models
litellm_params["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
litellm_params["model"] = model_name
litellm.completion(**litellm_params)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
healthy_endpoints.append(cleaned_params)
except Exception as e:
print("Got Exception", e)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
unhealthy_endpoints.append(cleaned_params)
pass

healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)

return {
"healthy_endpoints": healthy_endpoints,
"unhealthy_endpoints": unhealthy_endpoints
"unhealthy_endpoints": unhealthy_endpoints,
"healthy_count": len(healthy_endpoints),
"unhealthy_count": len(unhealthy_endpoints),
}

@router.get("/")
Expand Down

0 comments on commit a4cf4e7

Please sign in to comment.