From fc31221b8a5bc5ba4790942d62a511cf32ceb8ac Mon Sep 17 00:00:00 2001
From: Frank Colson
Date: Tue, 5 Dec 2023 22:09:01 -0700
Subject: [PATCH 1/2] Speedup health endpoint

---
 litellm/health_check.py       | 115 ++++++++++++++++++++++++++++++++++
 litellm/proxy/proxy_server.py |  34 +++-------
 2 files changed, 123 insertions(+), 26 deletions(-)
 create mode 100644 litellm/health_check.py

diff --git a/litellm/health_check.py b/litellm/health_check.py
new file mode 100644
index 000000000000..08cbffed53e6
--- /dev/null
+++ b/litellm/health_check.py
@@ -0,0 +1,115 @@
+import asyncio
+import random
+from typing import Optional
+
+import litellm
+import logging
+from concurrent.futures import ThreadPoolExecutor
+
+
+logger = logging.getLogger(__name__)
+
+
+ILLEGAL_DISPLAY_PARAMS = [
+    "messages",
+    "api_key"
+]
+
+
+def _get_random_llm_message():
+    """
+    Get a random message to send to the LLM.
+    """
+    messages = [
+        "Hey how's it going?",
+        "What's 1 + 1?"
+    ]
+
+
+    return [
+        {"role": "user", "content": random.choice(messages)}
+    ]
+
+
+def _clean_litellm_params(litellm_params: dict):
+    """
+    Clean the litellm params for display to users.
+    """
+    return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}
+
+
+async def _perform_health_check(model_list: list):
+    """
+    Perform a health check for each model in the list.
+    """
+
+    async def _check_model(model_params: dict):
+        try:
+            await litellm.acompletion(**model_params)
+        except Exception as e:
+            logger.exception("Health check failed for model %s", model_params["model"])
+            return False
+
+        return True
+
+    prepped_params = []
+
+    for model in model_list:
+        litellm_params = model["litellm_params"]
+        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
+        litellm_params["messages"] = _get_random_llm_message()
+
+        prepped_params.append(litellm_params)
+
+
+    tasks = [_check_model(x) for x in prepped_params]
+
+    results = await asyncio.gather(*tasks)
+
+    healthy_endpoints = []
+    unhealthy_endpoints = []
+
+    for is_healthy, model in zip(results, model_list):
+        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])
+
+        if is_healthy:
+            healthy_endpoints.append(cleaned_litellm_params)
+        else:
+            unhealthy_endpoints.append(cleaned_litellm_params)
+
+    return healthy_endpoints, unhealthy_endpoints
+
+
+
+
+async def perform_health_check(model_list: list, model: Optional[str] = None):
+    """
+    Perform a health check on the system.
+
+    Returns:
+        (healthy_endpoints, unhealthy_endpoints): two lists of cleaned litellm params.
+    """
+    if not model_list:
+        return [], []
+
+    if model is not None:
+        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]
+
+    models_to_check = []
+
+    for model in model_list:
+        litellm_params = model["litellm_params"]
+        model_name = litellm.utils.remove_model_id(litellm_params["model"])
+
+        if model_name in litellm.all_embedding_models:
+            continue  # Skip embedding models
+
+
+        models_to_check.append(model)
+
+
+    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(models_to_check)
+
+    return healthy_endpoints, unhealthy_endpoints
+
+
\ No newline at end of file
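The new helper is self-contained enough to exercise outside the proxy. A minimal sketch, assuming a config-style model_list where each entry carries litellm_params; the model names and API keys below are placeholders, not part of the patch:

import asyncio
from litellm.health_check import perform_health_check

# Placeholder entries shaped like the proxy's llm_model_list (model_name + litellm_params).
model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"}},
    {"model_name": "azure-gpt", "litellm_params": {"model": "azure/my-deployment", "api_key": "azure-placeholder", "api_base": "https://example.openai.azure.com"}},
]

# All completion calls are fired concurrently via asyncio.gather inside _perform_health_check,
# which is what makes this faster than the old sequential loop in the proxy.
healthy, unhealthy = asyncio.run(perform_health_check(model_list))
print(f"healthy: {len(healthy)}, unhealthy: {len(unhealthy)}")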
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6ee98a31ecac..09176a35c648 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -7,6 +7,8 @@
 import hashlib, uuid
 import warnings
 import importlib
+
+from litellm.health_check import perform_health_check
 messages: list = []
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -1131,34 +1133,14 @@ async def test_endpoint(request: Request):
 @router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"])
 async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
     global llm_model_list
-    healthy_endpoints = []
-    unhealthy_endpoints = []
-    if llm_model_list:
-        for model_name in llm_model_list:
-            try:
-                if model is None or model == model_name["litellm_params"]["model"]: # if model specified, just call that one.
-                    litellm_params = model_name["litellm_params"]
-                    model_name = litellm.utils.remove_model_id(litellm_params["model"]) # removes, ids set by litellm.router
-                    if model_name not in litellm.all_embedding_models: # filter out embedding models
-                        litellm_params["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
-                        litellm_params["model"] = model_name
-                        litellm.completion(**litellm_params)
-                        cleaned_params = {}
-                        for key in litellm_params:
-                            if key != "api_key" and key != "messages":
-                                cleaned_params[key] = litellm_params[key]
-                        healthy_endpoints.append(cleaned_params)
-            except Exception as e:
-                print("Got Exception", e)
-                cleaned_params = {}
-                for key in litellm_params:
-                    if key != "api_key" and key != "messages":
-                        cleaned_params[key] = litellm_params[key]
-                unhealthy_endpoints.append(cleaned_params)
-                pass
+
+    healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
+
     return {
         "healthy_endpoints": healthy_endpoints,
-        "unhealthy_endpoints": unhealthy_endpoints
+        "unhealthy_endpoints": unhealthy_endpoints,
+        "healthy_count": len(healthy_endpoints),
+        "unhealthy_count": len(unhealthy_endpoints),
     }
 
 @router.get("/")
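With patch 1 applied and the proxy running, the reworked route can be queried like any other endpoint. A sketch assuming a local proxy; the URL, port, and model name are placeholders:

import requests

resp = requests.get(
    "http://localhost:8000/health",       # placeholder proxy address
    params={"model": "gpt-3.5-turbo"},    # optional filter, mirrors the fastapi.Query parameter
)
report = resp.json()

# Shape returned by the handler above.
print(report["healthy_count"], report["unhealthy_count"])
for endpoint in report["unhealthy_endpoints"]:
    print("unhealthy:", endpoint)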
From 95e5331090d658c9de5be353ba19950d0a16c94f Mon Sep 17 00:00:00 2001
From: Frank Colson
Date: Tue, 5 Dec 2023 22:28:23 -0700
Subject: [PATCH 2/2] Use litellm logging convention

---
 litellm/__init__.py     | 2 +-
 litellm/_logging.py     | 5 +++++
 litellm/health_check.py | 4 ++--
 3 files changed, 8 insertions(+), 3 deletions(-)
 create mode 100644 litellm/_logging.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index b9cf85a55e42..837ca2b0132f 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -2,6 +2,7 @@
 import threading, requests
 from typing import Callable, List, Optional, Dict, Union, Any
 from litellm.caching import Cache
+from litellm._logging import set_verbose
 import httpx
 
 input_callback: List[Union[str, Callable]] = []
@@ -11,7 +12,6 @@
 _async_success_callback: List[Callable] = []  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
-set_verbose = False
 email: Optional[
     str
 ] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
diff --git a/litellm/_logging.py b/litellm/_logging.py
new file mode 100644
index 000000000000..fd21d22e3559
--- /dev/null
+++ b/litellm/_logging.py
@@ -0,0 +1,5 @@
+set_verbose = False
+
+def print_verbose(print_statement):
+    if set_verbose:
+        print(print_statement)
\ No newline at end of file
diff --git a/litellm/health_check.py b/litellm/health_check.py
index 08cbffed53e6..308382347d6c 100644
--- a/litellm/health_check.py
+++ b/litellm/health_check.py
@@ -4,7 +4,7 @@
 
 import litellm
 import logging
-from concurrent.futures import ThreadPoolExecutor
+from litellm._logging import print_verbose
 
 
 logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ async def _check_model(model_params: dict):
         try:
             await litellm.acompletion(**model_params)
         except Exception as e:
-            logger.exception("Health check failed for model %s", model_params["model"])
+            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
             return False
 
         return True
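A sketch of the gating behaviour the new convention relies on (not part of the patch): print_verbose re-reads the module-level flag on every call, so the flag is flipped on litellm._logging itself. Because litellm/__init__.py imports set_verbose by value, rebinding litellm.set_verbose alone would not change what print_verbose reads.

import litellm._logging as _logging

_logging.print_verbose("suppressed")    # set_verbose is False, so nothing is printed
_logging.set_verbose = True             # flip the flag in the module print_verbose reads from
_logging.print_verbose("now printed")   # prints, because the flag is re-checked on every call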