From 77152679896d4179c528528aaa99a8f8f88ab152 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 1 Jun 2024 16:09:41 -0700
Subject: [PATCH] fix(router.py): simplify scheduler

move the scheduler poll queuing logic into the router class, making it easier to use
---
 docs/my-website/docs/scheduler.md | 56 +++----------------
 litellm/proxy/proxy_server.py     | 60 +--------------------
 litellm/router.py                 | 91 ++++++++++++++++++++++++++++++-
 litellm/scheduler.py              | 42 ++++++--------
 litellm/tests/test_scheduler.py   | 59 ++++++++++++++++++++
 5 files changed, 177 insertions(+), 131 deletions(-)

diff --git a/docs/my-website/docs/scheduler.md b/docs/my-website/docs/scheduler.md
index 347406ade25a..10ae6efc43b0 100644
--- a/docs/my-website/docs/scheduler.md
+++ b/docs/my-website/docs/scheduler.md
@@ -22,9 +22,7 @@ Prioritize LLM API requests in high-traffic.
 ## Quick Start
 
 ```python
-from litellm import Scheduler, FlowItem, Router
-
-scheduler = Scheduler()
+from litellm import Router
 
 router = Router(
     model_list=[
@@ -39,53 +37,17 @@ router = Router(
     ],
     timeout=2, # timeout request if takes > 2s
     routing_strategy="usage-based-routing-v2",
+    polling_interval=0.03 # poll queue every 30ms if no healthy deployments
 )
-
-scheduler.update_variables(llm_router=router)
-
-### 🚨 IMPORTANT ###
-
-item = FlowItem(
-    priority=0, # 👈 SET PRIORITY FOR REQUEST
-    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
-    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
-)
-
-### [fin] IMPORTANT ###
-
-## ADDS REQUEST TO QUEUE ##
-await scheduler.add_request(request=item)
-
-## POLL QUEUE
-default_timeout = router.timeout
-end_time = time.time() + default_timeout
-poll_interval = 0.03 # poll every 3ms
-curr_time = time.time()
-
-make_request = False
-
-while curr_time < end_time:
-    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
-        id=item.request_id, model_name=item.model_name
+try:
+    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
+        model="gpt-3.5-turbo", # 👈 SAME as 'Router'
+        messages=[{"role": "user", "content": "Hey!"}],
+        priority=0, # 👈 LOWER IS BETTER
     )
-    if make_request: ## IF TRUE -> MAKE REQUEST
-        break
-    else: ## ELSE -> loop till default_timeout
-        await asyncio.sleep(poll_interval)
-        curr_time = time.time()
-
-if make_request:
-    try:
-        _response = await router.acompletion(
-            model=item.model_name,
-            messages=[{"role": "user", "content": "Hey!"}],
-        )
-    except Exception as e:
-        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
-
-    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
-
-print("didn't make request")
+except Exception as e:
+    print("didn't make request")
 ```
 
 ## LiteLLM Proxy
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4e7fb56bdb1b..30cc9e20e4a9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -398,8 +398,6 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
 async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
-### SIMPLE QUEUE ###
-simple_scheduler = Scheduler()
 ### DB WRITER ###
 db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -3705,7 +3703,7 @@ def on_backoff(details):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, 
litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db import json ### LOAD MASTER KEY ### @@ -3741,10 +3739,6 @@ async def startup_event(): ## Error Tracking ## error_tracking() - ## Priority Workload Scheduler ## - if llm_router is not None: - simple_scheduler.update_variables(llm_router=llm_router) - ## UPDATE SLACK ALERTING ## proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router) @@ -12183,47 +12177,12 @@ async def async_queue_request( if user_api_base: data["api_base"] = user_api_base - ## FLOW ITEM ## - request_id = str(uuid.uuid4()) - flow_item = FlowItem( - priority=data.pop("priority", DefaultPriorities.Medium.value), - request_id=request_id, - model_name=data["model"], - ) - # [TODO] only allow premium users to set non default priorities - - ## ADD REQUEST TO QUEUE - response = await simple_scheduler.add_request(request=flow_item) - - if llm_router is None: - raise HTTPException( - status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} - ) - ## POLL QUEUE - default_timeout = llm_router.timeout - end_time = time.time() + default_timeout - poll_interval = 0.03 # poll every 3ms - curr_time = time.time() - - make_request = False - if llm_router is None: raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - while curr_time < end_time: - make_request = await simple_scheduler.poll( - id=request_id, model_name=data["model"] - ) - if make_request: ## IF TRUE -> MAKE REQUEST - break - else: ## ELSE -> loop till default_timeout - await asyncio.sleep(poll_interval) - curr_time = time.time() - - if make_request: - response = await llm_router.acompletion(**data) + response = await llm_router.schedule_acompletion(**data) if ( "stream" in data and data["stream"] == True @@ -12237,7 +12196,7 @@ async def async_queue_request( media_type="text/event-stream", ) - fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)}) + fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])}) return response except Exception as e: await proxy_logging_obj.post_call_failure_hook( @@ -12260,19 +12219,6 @@ async def async_queue_request( ) -@router.get( - "/queue/info", - tags=["experimental"], - dependencies=[Depends(user_api_key_auth)], -) -async def queue_info( - request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -) -> List: - """Help user know the status of an item in the queue""" - return simple_scheduler.get_queue_status() - - @router.get( "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"] ) diff --git a/litellm/router.py b/litellm/router.py index 88eb54a04c7b..99a61bcbf7b3 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -62,6 +62,7 @@ Run, AssistantToolParam, ) +from litellm.scheduler import Scheduler, FlowItem from typing import Iterable @@ -87,6 +88,8 @@ def __init__( List[tuple] ] = None, # if you want to cache across model groups client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds + ## SCHEDULER ## + polling_interval: Optional[float] = None, ## RELIABILITY ## num_retries: Optional[int] = None, timeout: Optional[float] = None, @@ -141,7 +144,8 @@ def __init__( cache_kwargs (dict): Additional kwargs 
to pass to RedisCache. Defaults to {}. caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None. client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600. - num_retries (int): Number of retries for failed requests. Defaults to 0. + polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. + num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2. timeout (Optional[float]): Timeout for requests. Defaults to None. default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}. set_verbose (bool): Flag to set verbose mode. Defaults to False. @@ -208,6 +212,8 @@ def __init__( [] ) # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_latency_map = {} + ### SCHEDULER ### + self.scheduler = Scheduler(polling_interval=polling_interval) ### CACHING ### cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache redis_cache = None @@ -533,11 +539,17 @@ async def acompletion( ) -> ModelResponse: ... + @overload + async def acompletion( + self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs + ) -> Union[CustomStreamWrapper, ModelResponse]: + ... + # fmt: on # The actual implementation of the function async def acompletion( - self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs + self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs ): try: kwargs["model"] = model @@ -905,6 +917,81 @@ async def check_response(task: asyncio.Task): # If we exit the loop without returning, all tasks failed raise Exception("All tasks failed") + ### SCHEDULER ### + + # fmt: off + + @overload + async def schedule_acompletion( + self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs + ) -> ModelResponse: + ... + + @overload + async def schedule_acompletion( + self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs + ) -> CustomStreamWrapper: + ... 
+
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name=model,  # 👈 SAME model group as the request
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # defaults to 30ms (0.03s)
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
diff --git a/litellm/scheduler.py b/litellm/scheduler.py
index 3bbd3916e607..605bcbb23445 100644
--- a/litellm/scheduler.py
+++ b/litellm/scheduler.py
@@ -3,7 +3,6 @@
 from typing import Optional
 import enum
 from litellm.caching import DualCache
-from litellm import Router
 from litellm import print_verbose
 
 
@@ -25,14 +24,16 @@ class FlowItem(BaseModel):
 
 class Scheduler:
     cache: DualCache
-    llm_router: Optional[Router] = None
 
-    def __init__(self):
-        self.queue = []
+    def __init__(self, polling_interval: Optional[float] = None):
+        """
+        polling_interval: float or null - frequency of polling queue. Default is 0.03s (30ms).
+        """
+        self.queue: list = []
         self.cache = DualCache()
+        self.polling_interval = polling_interval or 0.03  # default to 30ms
 
-    def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None):
-        self.llm_router = llm_router
+    def update_variables(self, cache: Optional[DualCache] = None):
         if cache is not None:
             self.cache = cache
 
@@ -46,7 +47,7 @@ async def add_request(self, request: FlowItem):
         # save the queue
         await self.save_queue(queue=queue, model_name=request.model_name)
 
-    async def poll(self, id: str, model_name: str) -> bool:
+    async def poll(self, id: str, model_name: str, health_deployments: list) -> bool:
         """
         Return if request can be processed.
 
@@ -59,22 +60,17 @@ async def poll(self, id: str, model_name: str) -> bool:
         * AND request not at the top of queue
         """
        queue = await self.get_queue(model_name=model_name)
-        if not queue or not self.llm_router:
+        if not queue:
             raise Exception(
-                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
-                    queue, self.llm_router
-                )
+                "Incorrectly setup. Queue is invalid. 
Queue={}".format(queue) ) # ------------ # Setup values # ------------ - _healthy_deployments = await self.llm_router._async_get_healthy_deployments( - model=model_name - ) - print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}") - if len(_healthy_deployments) == 0: + print_verbose(f"len(health_deployments): {len(health_deployments)}") + if len(health_deployments) == 0: print_verbose(f"queue: {queue}, seeking id={id}") # Check if the id is at the top of the heap if queue[0][1] == id: @@ -87,23 +83,19 @@ async def poll(self, id: str, model_name: str) -> bool: return True - async def peek(self, id: str, model_name: str) -> bool: + async def peek(self, id: str, model_name: str, health_deployments: list) -> bool: """Return if the id is at the top of the queue. Don't pop the value from heap.""" queue = await self.get_queue(model_name=model_name) - if not queue or not self.llm_router: + if not queue: raise Exception( - "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format( - queue, self.llm_router - ) + "Incorrectly setup. Queue is invalid. Queue={}".format(queue) ) # ------------ # Setup values # ------------ - _healthy_deployments = await self.llm_router._async_get_healthy_deployments( - model=model_name - ) - if len(_healthy_deployments) == 0: + + if len(health_deployments) == 0: return False # Check if the id is at the top of the heap diff --git a/litellm/tests/test_scheduler.py b/litellm/tests/test_scheduler.py index 20756c4fa4be..43fe85de1a7d 100644 --- a/litellm/tests/test_scheduler.py +++ b/litellm/tests/test_scheduler.py @@ -4,12 +4,14 @@ import sys, os, time, openai, uuid import traceback, asyncio import pytest +from typing import List sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from litellm import Router from litellm.scheduler import FlowItem, Scheduler +from litellm import ModelResponse @pytest.mark.asyncio @@ -172,3 +174,60 @@ async def _make_prioritized_call(flow_item: FlowItem): assert ( completed_responses[0][2] < completed_responses[1][2] ) # higher priority request tried first + + +@pytest.mark.parametrize("p0, p1", [(0, 1), (0, 0)]) # +@pytest.mark.asyncio +async def test_aascheduler_prioritized_requests_mock_response_simplified(p0, p1): + """ + 2 requests for same model group + + if model is at rate limit, ensure the higher priority request gets done first + """ + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "mock_response": "Hello world this is Macintosh!", + "rpm": 0, + }, + }, + ], + timeout=10, + num_retries=3, + cooldown_time=5, + routing_strategy="usage-based-routing-v2", + ) + + tasks = [] + + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hey, how's it going?"}], + } + + tasks.append(router.schedule_acompletion(**data, priority=p0)) + tasks.append(router.schedule_acompletion(**data, priority=p1)) + + # Running the tasks and getting responses in order of completion + completed_responses: List[dict] = [] + for task in asyncio.as_completed(tasks): + try: + result = await task + except Exception as e: + result = {"priority": e.priority, "response_completed_at": time.time()} + completed_responses.append(result) + print(f"Received response: {result}") + + print(f"responses: {completed_responses}") + + assert ( + completed_responses[0]["priority"] == 0 + ) # assert higher priority request got done first + assert ( + completed_responses[0]["response_completed_at"] + < 
completed_responses[1]["response_completed_at"] + ) # higher priority request tried first
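
A minimal end-to-end sketch of the priority-queuing interface this patch introduces, for anyone who wants to try it locally: it calls `Router.schedule_acompletion()` against a `mock_response` deployment, so no real completion is made. The model alias, message, and timeout values are illustrative, and depending on the environment the underlying OpenAI client may still expect `OPENAI_API_KEY` to be set.

```python
# Sketch (assumes litellm at this commit): queue a request with priority 0 and
# let the router poll its scheduler until the request can be served.
import asyncio

from litellm import Router


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "Hello world!",  # mocked, no real API call
                },
            },
        ],
        timeout=10,
        polling_interval=0.03,  # poll the scheduler queue every 30ms while waiting
    )

    try:
        response = await router.schedule_acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            priority=0,  # lower value = higher priority
        )
        print(response)
    except Exception as e:
        # schedule_acompletion raises litellm.Timeout if the request never
        # reaches the front of the queue before `timeout`; the failing
        # request's priority is attached to the exception.
        print("didn't make request, priority={}".format(getattr(e, "priority", None)))


asyncio.run(main())
```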
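
Because `Scheduler.poll()` now takes the healthy-deployment list as an argument instead of holding a `Router` reference, the scheduler can also be exercised on its own. A sketch under that assumption, using an empty `health_deployments` list to stand in for a fully rate-limited model group:

```python
# Sketch: drive the decoupled Scheduler directly, the way
# Router.schedule_acompletion() does internally.
import asyncio
import uuid

from litellm.scheduler import FlowItem, Scheduler


async def main():
    scheduler = Scheduler(polling_interval=0.03)

    item = FlowItem(
        priority=0,  # lower value = served first
        request_id=str(uuid.uuid4()),
        model_name="gpt-3.5-turbo",
    )
    await scheduler.add_request(request=item)

    # With no healthy deployments, poll() only returns True once this request
    # is at the front of the queue; with healthy deployments it returns True
    # immediately and leaves ordering to the routing strategy.
    can_run = await scheduler.poll(
        id=item.request_id,
        model_name=item.model_name,
        health_deployments=[],  # pretend every deployment is rate limited
    )
    print("make request now?", can_run)


asyncio.run(main())
```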