fix(proxy_server.py): enable aggregate queries via /spend/keys #1901

Merged 2 commits on Feb 9, 2024
18 changes: 17 additions & 1 deletion litellm/main.py
@@ -31,6 +31,7 @@
get_llm_provider,
get_api_key,
mock_completion_streaming_obj,
+ async_mock_completion_streaming_obj,
convert_to_model_response_object,
token_counter,
Usage,
@@ -307,6 +308,7 @@ def mock_completion(
messages: List,
stream: Optional[bool] = False,
mock_response: str = "This is a mock request",
+ logging=None,
**kwargs,
):
"""
@@ -335,6 +337,15 @@
model_response = ModelResponse(stream=stream)
if stream is True:
# don't try to access stream object,
+ if kwargs.get("acompletion", False) == True:
+ return CustomStreamWrapper(
+ completion_stream=async_mock_completion_streaming_obj(
+ model_response, mock_response=mock_response, model=model
+ ),
+ model=model,
+ custom_llm_provider="openai",
+ logging_obj=logging,
+ )
response = mock_completion_streaming_obj(
model_response, mock_response=mock_response, model=model
)
@@ -717,7 +728,12 @@ def completion(
)
if mock_response:
return mock_completion(
- model, messages, stream=stream, mock_response=mock_response
+ model,
+ messages,
+ stream=stream,
+ mock_response=mock_response,
+ logging=logging,
+ acompletion=acompletion,
)
if custom_llm_provider == "azure":
# azure configs
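
The main.py changes route mocked completions through an async-aware streaming path: when acompletion=True and stream=True, mock_completion now wraps the new async_mock_completion_streaming_obj generator in a CustomStreamWrapper instead of falling through to the sync generator. A minimal usage sketch of what this enables; the model name is illustrative, and no provider call is made because mock_response short-circuits it:

import asyncio
import litellm

async def main():
    # mock_response short-circuits the provider call; stream=True now also works on the async path
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # illustrative only, never actually called
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        mock_response="it's going great",
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta)

asyncio.run(main())
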
4 changes: 2 additions & 2 deletions litellm/proxy/hooks/parallel_request_limiter.py
@@ -125,7 +125,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
# ------------

new_val = {
"current_requests": current["current_requests"] - 1,
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"] + 1,
}
@@ -183,7 +183,7 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti
}

new_val = {
"current_requests": current["current_requests"] - 1,
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
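
The limiter change clamps the in-flight counter at zero, so a success or failure callback that fires for a request the pre-call hook never recorded (or that was already decremented) cannot push current_requests negative. A small standalone illustration of the clamp, using made-up counter values:

current = {"current_requests": 0, "current_tpm": 120, "current_rpm": 4}  # hypothetical cache entry
total_tokens = 30

new_val = {
    "current_requests": max(current["current_requests"] - 1, 0),  # clamped to 0 instead of going to -1
    "current_tpm": current["current_tpm"] + total_tokens,
    "current_rpm": current["current_rpm"] + 1,
}
assert new_val["current_requests"] == 0
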
47 changes: 44 additions & 3 deletions litellm/proxy/proxy_server.py
@@ -3015,7 +3015,16 @@ async def info_key_fn(
tags=["budget & spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
)
- async def spend_key_fn():
+ async def spend_key_fn(
+ start_date: Optional[str] = fastapi.Query(
+ default=None,
+ description="Time from which to start viewing key spend",
+ ),
+ end_date: Optional[str] = fastapi.Query(
+ default=None,
+ description="Time till which to view key spend",
+ ),
+ ):
"""
View all keys created, ordered by spend

@@ -3032,9 +3041,41 @@ async def spend_key_fn():
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)

- key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
+ if (
+ start_date is not None
+ and isinstance(start_date, str)
+ and end_date is not None
+ and isinstance(end_date, str)
+ ):
+ # Convert the date strings to datetime objects
+ start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
+ end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
+
+ # SQL query
+ response = await prisma_client.db.litellm_spendlogs.group_by(
+ by=["api_key", "startTime"],
+ where={
+ "startTime": {
+ "gte": start_date_obj, # Greater than or equal to Start Date
+ "lte": end_date_obj, # Less than or equal to End Date
+ }
+ },
+ sum={
+ "spend": True,
+ },
+ )

- return key_info
+ # TODO: Execute SQL query and return the results
+
+ return {
+ "message": "This is your SQL query",
+ "response": response,
+ }
+ else:
+ key_info = await prisma_client.get_data(
+ table_name="key", query_type="find_all"
+ )
+ return key_info

except Exception as e:
raise HTTPException(
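
With the two new optional query parameters, /spend/keys aggregates spend per api_key over the given date window (dates are parsed with %Y-%m-%d); without them it keeps the old behaviour of listing all keys ordered by spend. A hedged client-side sketch — the proxy URL and key below are placeholders:

import requests

resp = requests.get(
    "http://localhost:4000/spend/keys",  # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    params={"start_date": "2024-02-01", "end_date": "2024-02-09"},  # must match %Y-%m-%d
)
print(resp.json())
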
3 changes: 3 additions & 0 deletions litellm/tests/test_parallel_request_limiter.py
@@ -292,6 +292,7 @@ async def test_normal_router_call():
model="azure-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
metadata={"user_api_key": _api_key},
+ mock_response="hello",
)
await asyncio.sleep(1) # success is done in a separate thread
print(f"response: {response}")
@@ -450,6 +451,7 @@ async def test_streaming_router_call():
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=True,
metadata={"user_api_key": _api_key},
+ mock_response="hello",
)
async for chunk in response:
continue
@@ -526,6 +528,7 @@ async def test_streaming_router_tpm_limit():
messages=[{"role": "user", "content": "Write me a paragraph on the moon"}],
stream=True,
metadata={"user_api_key": _api_key},
+ mock_response="hello",
)
async for chunk in response:
continue
10 changes: 9 additions & 1 deletion litellm/utils.py
@@ -1576,7 +1576,7 @@ async def async_success_handler(
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(f"Async success callbacks: CustomLogger")
print_verbose(f"Async success callbacks: {callback}")
if self.stream:
if "complete_streaming_response" in self.model_call_details:
await callback.async_log_success_event(
@@ -8819,6 +8819,14 @@ def mock_completion_streaming_obj(model_response, mock_response, model):
yield model_response


+ async def async_mock_completion_streaming_obj(model_response, mock_response, model):
+ for i in range(0, len(mock_response), 3):
+ completion_obj = Delta(role="assistant", content=mock_response)
+ model_response.choices[0].delta = completion_obj
+ model_response.choices[0].finish_reason = "stop"
+ yield model_response


########## Reading Config File ############################
def read_config_args(config_path) -> dict:
try:
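
For reference, a self-contained sketch of how a generator of this shape is consumed; SimpleNamespace stands in for litellm's Delta/ModelResponse objects, and the function name below is a hypothetical stub rather than the real import:

import asyncio
from types import SimpleNamespace

async def fake_async_mock_stream(model_response, mock_response):
    # mirrors the loop above: one chunk is yielded per 3-character step over the mock response
    for i in range(0, len(mock_response), 3):
        model_response.choices[0].delta = SimpleNamespace(role="assistant", content=mock_response)
        model_response.choices[0].finish_reason = "stop"
        yield model_response

async def main():
    stub = SimpleNamespace(choices=[SimpleNamespace(delta=None, finish_reason=None)])
    async for chunk in fake_async_mock_stream(stub, "hello world"):
        print(chunk.choices[0].delta.content)

asyncio.run(main())
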