fix(router.py): simplify scheduler
move the scheduler poll queuing logic into the router class, making it easier to use
krrishdholakia committed Jun 1, 2024
1 parent 27087f6 commit 7715267
Showing 5 changed files with 177 additions and 131 deletions.
56 changes: 9 additions & 47 deletions docs/my-website/docs/scheduler.md
````diff
@@ -22,9 +22,7 @@ Prioritize LLM API requests in high-traffic.
 ## Quick Start
 
 ```python
-from litellm import Scheduler, FlowItem, Router
-
-scheduler = Scheduler()
+from litellm import Router
 
 router = Router(
     model_list=[
@@ -39,53 +37,17 @@ router = Router(
     ],
     timeout=2, # timeout request if takes > 2s
     routing_strategy="usage-based-routing-v2",
+    polling_interval=0.03 # poll queue every 3ms if no healthy deployments
 )
 
-scheduler.update_variables(llm_router=router)
-
-### 🚨 IMPORTANT ###
-
-item = FlowItem(
-    priority=0, # 👈 SET PRIORITY FOR REQUEST
-    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
-    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
-)
-
-### [fin] IMPORTANT ###
-
-## ADDS REQUEST TO QUEUE ##
-await scheduler.add_request(request=item)
-
-## POLL QUEUE
-default_timeout = router.timeout
-end_time = time.time() + default_timeout
-poll_interval = 0.03 # poll every 3ms
-curr_time = time.time()
-
-make_request = False
-
-while curr_time < end_time:
-    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
-        id=item.request_id, model_name=item.model_name
+try:
+    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
+        model=item.model_name,
+        messages=[{"role": "user", "content": "Hey!"}],
+        priority=0, # 👈 LOWER IS BETTER
     )
-    if make_request: ## IF TRUE -> MAKE REQUEST
-        break
-    else: ## ELSE -> loop till default_timeout
-        await asyncio.sleep(poll_interval)
-        curr_time = time.time()
-
-if make_request:
-    try:
-        _response = await router.acompletion(
-            model=item.model_name,
-            messages=[{"role": "user", "content": "Hey!"}],
-        )
-    except Exception as e:
-        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
-
-    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
-
-print("didn't make request")
+except Exception as e:
+    print("didn't make request")
 ```
 
 ## LiteLLM Proxy
````

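Put together, the new Quick Start reduces to a single `schedule_acompletion` call. Below is a minimal runnable sketch of that flow; the `mock_response` deployment, the `"gpt-3.5-turbo"` literal, and the `asyncio.run` wrapper are illustrative choices, not part of this diff.

```python
import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "Hello world!"},
        }
    ],
    timeout=2,              # give up if a request takes > 2s end-to-end
    polling_interval=0.03,  # seconds between queue polls when no deployment is free
)

async def main():
    # schedule_acompletion queues the request, polls until a deployment is healthy
    # (or the router timeout expires), then makes the underlying acompletion call
    response = await router.schedule_acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0,  # lower value = higher priority
    )
    print(response)

asyncio.run(main())
```
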
60 changes: 3 additions & 57 deletions litellm/proxy/proxy_server.py
```diff
@@ -398,8 +398,6 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
 async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
-### SIMPLE QUEUE ###
-simple_scheduler = Scheduler()
 ### DB WRITER ###
 db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -3705,7 +3703,7 @@ def on_backoff(details):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json
 
     ### LOAD MASTER KEY ###
@@ -3741,10 +3739,6 @@ async def startup_event():
     ## Error Tracking ##
     error_tracking()
 
-    ## Priority Workload Scheduler ##
-    if llm_router is not None:
-        simple_scheduler.update_variables(llm_router=llm_router)
-
     ## UPDATE SLACK ALERTING ##
     proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
 
@@ -12183,47 +12177,12 @@ async def async_queue_request(
         if user_api_base:
             data["api_base"] = user_api_base
 
-        ## FLOW ITEM ##
-        request_id = str(uuid.uuid4())
-        flow_item = FlowItem(
-            priority=data.pop("priority", DefaultPriorities.Medium.value),
-            request_id=request_id,
-            model_name=data["model"],
-        )
-        # [TODO] only allow premium users to set non default priorities
-
-        ## ADD REQUEST TO QUEUE
-        response = await simple_scheduler.add_request(request=flow_item)
 
         if llm_router is None:
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        ## POLL QUEUE
-        default_timeout = llm_router.timeout
-        end_time = time.time() + default_timeout
-        poll_interval = 0.03  # poll every 3ms
-        curr_time = time.time()
-
-        make_request = False
-
-        if llm_router is None:
-            raise HTTPException(
-                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
-            )
-
-        while curr_time < end_time:
-            make_request = await simple_scheduler.poll(
-                id=request_id, model_name=data["model"]
-            )
-            if make_request:  ## IF TRUE -> MAKE REQUEST
-                break
-            else:  ## ELSE -> loop till default_timeout
-                await asyncio.sleep(poll_interval)
-                curr_time = time.time()
-
-        if make_request:
-            response = await llm_router.acompletion(**data)
+        response = await llm_router.schedule_acompletion(**data)
 
         if (
             "stream" in data and data["stream"] == True
@@ -12237,7 +12196,7 @@ async def async_queue_request(
                 media_type="text/event-stream",
             )
 
-        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
+        fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -12260,19 +12219,6 @@ async def async_queue_request(
         )
 
 
-@router.get(
-    "/queue/info",
-    tags=["experimental"],
-    dependencies=[Depends(user_api_key_auth)],
-)
-async def queue_info(
-    request: Request,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-) -> List:
-    """Help user know the status of an item in the queue"""
-    return simple_scheduler.get_queue_status()
-
-
 @router.get(
     "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
 )
```

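On the proxy side, `async_queue_request` now just forwards the request body, including `priority`, to `llm_router.schedule_acompletion(**data)`. A client-side sketch, assuming the handler is mounted at `/queue/chat/completions` and the proxy runs locally with master key `sk-1234` (both are assumptions, not shown in this diff):

```python
import requests

# illustrative endpoint and key; adjust to your proxy's actual route and credentials
resp = requests.post(
    "http://0.0.0.0:4000/queue/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey!"}],
        "priority": 0,  # passed through **data to Router.schedule_acompletion
    },
)
# the handler echoes the priority back in a response header
print(resp.headers.get("x-litellm-priority"), resp.json())
```
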
91 changes: 89 additions & 2 deletions litellm/router.py
```diff
@@ -62,6 +62,7 @@
     Run,
     AssistantToolParam,
 )
+from litellm.scheduler import Scheduler, FlowItem
 from typing import Iterable
 
 
@@ -87,6 +88,8 @@ def __init__(
             List[tuple]
         ] = None,  # if you want to cache across model groups
         client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
+        ## SCHEDULER ##
+        polling_interval: Optional[float] = None,
         ## RELIABILITY ##
         num_retries: Optional[int] = None,
         timeout: Optional[float] = None,
@@ -141,7 +144,8 @@
             cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
             caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
             client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
+            polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
+            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
             timeout (Optional[float]): Timeout for requests. Defaults to None.
             default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
             set_verbose (bool): Flag to set verbose mode. Defaults to False.
@@ -208,6 +212,8 @@ def __init__(
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### SCHEDULER ###
+        self.scheduler = Scheduler(polling_interval=polling_interval)
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
         redis_cache = None
@@ -533,11 +539,17 @@ async def acompletion(
     ) -> ModelResponse:
         ...
 
+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+    ) -> Union[CustomStreamWrapper, ModelResponse]:
+        ...
+
     # fmt: on
 
     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
     ):
         try:
             kwargs["model"] = model
@@ -905,6 +917,81 @@ async def check_response(task: asyncio.Task):
         # If we exit the loop without returning, all tasks failed
         raise Exception("All tasks failed")
 
+    ### SCHEDULER ###
+
+    # fmt: off
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name="gpt-3.5-turbo",  # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
```

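The new `Router.schedule_acompletion` wraps the add-to-queue / poll / call sequence and raises `litellm.Timeout` if polling never returns a go-ahead before `router.timeout` expires. A sketch of how priorities might be exercised against it, with an illustrative `mock_response`, `rpm=1` deployment so the queue actually backs up (not part of the diff); note that, as committed, the internal `FlowItem` hardcodes `model_name="gpt-3.5-turbo"` rather than using the `model` argument.

```python
import asyncio
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "Hi!", "rpm": 1},
        }
    ],
    timeout=2,
    polling_interval=0.03,  # seconds between scheduler polls
)

async def send(priority: int) -> str:
    try:
        # queues the request, polls for a healthy deployment, then calls acompletion
        await router.schedule_acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": f"request with priority {priority}"}],
            priority=priority,  # lower number is served first when the queue backs up
        )
        return f"priority {priority}: ok"
    except litellm.Timeout:
        return f"priority {priority}: timed out while polling the queue"

async def main():
    # fire a burst of mixed-priority requests; priority-0 requests should win
    # the healthy deployment first while the rpm=1 limit throttles the rest
    print(await asyncio.gather(*(send(p) for p in (1, 0, 1, 0))))

asyncio.run(main())
```
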
(Diffs for the remaining 2 changed files are not shown here.)
