Skip to content

Commit

Permalink
Merge pull request #1498 from BerriAI/litellm_spend_tracking_logs
Browse files Browse the repository at this point in the history
[Feat] Proxy - Add Spend tracking logs
  • Loading branch information
ishaan-jaff authored Jan 18, 2024
2 parents 2e06e00 + 4294657 commit a262678
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 14 deletions.
24 changes: 21 additions & 3 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import BaseModel, Extra, Field, root_validator, Json
import enum
from typing import Optional, List, Union, Dict, Literal
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid, json
import uuid, json, sys, os


class LiteLLMBase(BaseModel):
Expand Down Expand Up @@ -196,6 +196,7 @@ class DynamoDBArgs(LiteLLMBase):
user_table_name: str = "LiteLLM_UserTable"
key_table_name: str = "LiteLLM_VerificationToken"
config_table_name: str = "LiteLLM_Config"
spend_table_name: str = "LiteLLM_SpendLogs"


class ConfigGeneralSettings(LiteLLMBase):
Expand Down Expand Up @@ -314,3 +315,20 @@ def set_model_info(cls, values):
if values.get("models") is None:
values.update({"models", []})
return values


class LiteLLM_SpendLogs(LiteLLMBase):
    """Pydantic model mirroring one row of the LiteLLM_SpendLogs table.

    One record per tracked LLM call. Populated by get_logging_payload()
    (litellm/proxy/utils.py) and written through the prisma / dynamo / custom
    db clients under table_name="spend".
    """

    request_id: str  # unique id of the call; hash/unique key in both DB schemas
    api_key: str  # SHA-256 hash of the caller's key (see hash_token), not the raw key
    model: Optional[str] = ""
    call_type: str  # e.g. "litellm.completion"
    spend: Optional[float] = 0.0  # response cost computed by the proxy — TODO confirm units
    startTime: Union[str, datetime, None]
    endTime: Union[str, datetime, None]
    user: Optional[str] = ""
    modelParameters: Optional[Json] = {}  # request optional_params, JSON-serialized
    messages: Optional[Json] = []
    response: Optional[Json] = {}
    usage: Optional[Json] = {}  # token usage block taken from the model response
    metadata: Optional[Json] = {}
    cache_hit: Optional[str] = "False"  # NOTE(review): prisma schema defaults this to "" — confirm intended default
21 changes: 20 additions & 1 deletion litellm/proxy/db/dynamo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,27 @@ async def connect(self):
raise Exception(
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
)

## Spend
try:
verbose_proxy_logger.debug("DynamoDB Wrapper - Creating Spend Table")
error_occurred = False
table = client.table(self.database_arguments.spend_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("request_id", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'"
)
verbose_proxy_logger.debug("DynamoDB Wrapper - Done connecting()")

async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
self, value: Any, table_name: Literal["user", "key", "config", "spend"]
):
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
Expand Down Expand Up @@ -166,6 +183,8 @@ async def insert_data(
table = client.table(self.database_arguments.key_table_name)
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
elif table_name == "spend":
table = client.table(self.database_arguments.spend_table_name)

for k, v in value.items():
if isinstance(v, datetime):
Expand Down
4 changes: 2 additions & 2 deletions litellm/proxy/proxy_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ litellm_settings:
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

# general_settings:
# master_key: sk-1234
general_settings:
master_key: sk-1234
# database_type: "dynamo_db"
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
# "billing_mode": "PAY_PER_REQUEST",
Expand Down
47 changes: 44 additions & 3 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def generate_feedback_box():
ProxyLogging,
_cache_user_row,
send_email,
get_logging_payload,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
Expand Down Expand Up @@ -518,6 +519,7 @@ async def track_cost_callback(
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
Expand All @@ -538,7 +540,13 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
)
elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost(
Expand All @@ -554,13 +562,27 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")


async def update_database(token, response_cost, user_id=None):
async def update_database(
token,
response_cost,
user_id=None,
kwargs=None,
completion_response=None,
start_time=None,
end_time=None,
):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
Expand Down Expand Up @@ -630,9 +652,28 @@ async def _update_key_db():
key=token, value={"spend": new_spend}, table_name="key"
)

async def _insert_spend_log_to_db():
    """Build the standardized spend-log payload for this call and persist it.

    Closure over update_database()'s arguments (kwargs, completion_response,
    start_time, end_time, response_cost) and the module-level db clients.
    """
    # Helper to generate payload to log
    verbose_proxy_logger.debug("inserting spend log to db")
    payload = get_logging_payload(
        kwargs=kwargs,
        response_obj=completion_response,
        start_time=start_time,
        end_time=end_time,
    )

    # Overwrite with the cost the caller already computed for this request.
    payload["spend"] = response_cost

    # Prefer prisma when configured; otherwise fall back to the custom client.
    if prisma_client is not None:
        await prisma_client.insert_data(data=payload, table_name="spend")

    elif custom_db_client is not None:
        await custom_db_client.insert_data(payload, table_name="spend")

tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks)
except Exception as e:
verbose_proxy_logger.debug(
Expand Down
17 changes: 17 additions & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
model LiteLLM_Config {
param_name String @id
param_value Json?
}

model LiteLLM_SpendLogs {
  request_id String @unique // one row per LLM call; upserts dedupe on this id
  call_type String // e.g. "litellm.completion"
  api_key String @default ("") // SHA-256 hash of the caller's key, not the raw token
  spend Float @default(0.0) // response cost for the call
  startTime DateTime // call start timestamp
  endTime DateTime // call end timestamp
  model String @default("")
  user String @default("")
  modelParameters Json @default("{}") // request optional_params, serialized to JSON
  messages Json @default("[]")
  response Json @default("{}")
  usage Json @default("{}") // token usage block from the model response
  metadata Json @default("{}")
  cache_hit String @default("") // NOTE(review): pydantic model defaults this to "False" — confirm intended default
}
109 changes: 105 additions & 4 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken
from litellm.proxy._types import (
UserAPIKeyAuth,
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_SpendLogs,
)
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
Expand Down Expand Up @@ -316,7 +321,7 @@ async def get_generic_data(
self,
key: str,
value: Any,
table_name: Literal["users", "keys", "config"],
table_name: Literal["users", "keys", "config", "spend"],
):
"""
Generic implementation of get data
Expand All @@ -334,6 +339,10 @@ async def get_generic_data(
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
elif table_name == "spend":
response = await self.db.l.find_first( # type: ignore
where={key: value} # type: ignore
)
return response
except Exception as e:
asyncio.create_task(
Expand Down Expand Up @@ -417,7 +426,7 @@ async def get_data(
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(
self, data: dict, table_name: Literal["user", "key", "config"]
self, data: dict, table_name: Literal["user", "key", "config", "spend"]
):
"""
Add a key to the database. If it already exists, do nothing.
Expand Down Expand Up @@ -473,8 +482,18 @@ async def insert_data(
)

tasks.append(updated_table_row)

await asyncio.gather(*tasks)
elif table_name == "spend":
db_data = self.jsonify_object(data=data)
new_spend_row = await self.db.litellm_spendlogs.upsert(
where={"request_id": data["request_id"]},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
return new_spend_row

except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
asyncio.create_task(
Expand Down Expand Up @@ -760,3 +779,85 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):

except Exception as e:
print_verbose("An error occurred while sending the email:", str(e))


def hash_token(token: str) -> str:
    """Return the SHA-256 hex digest of *token*.

    Used to obfuscate API keys before they are written into spend-log
    payloads (see get_logging_payload), so the raw key is never persisted.

    Args:
        token: The raw API key / token string.

    Returns:
        64-character lowercase hex SHA-256 digest of the UTF-8 encoded token.
    """
    # hashlib is already imported at module level; the previous
    # function-local `import hashlib` was redundant.
    return hashlib.sha256(token.encode()).hexdigest()


def get_logging_payload(kwargs, response_obj, start_time, end_time):
    """Build a standardized spend-log row from a completed LLM call.

    Normalizes the litellm callback ``kwargs``/``response_obj`` pair into a
    flat dict matching the ``LiteLLM_SpendLogs`` schema, so the same payload
    can be used across s3, DynamoDB and Langfuse logging.

    Args:
        kwargs: litellm callback kwargs (may be None; treated as empty).
        response_obj: model response; must expose ``["usage"]`` and ``.get("id")``.
        start_time: call start timestamp (str or datetime).
        end_time: call end timestamp (str or datetime).

    Returns:
        dict keyed by LiteLLM_SpendLogs field names, with Json-typed fields
        serialized to JSON strings and str-typed fields coerced to str.
    """
    from litellm.proxy._types import LiteLLM_SpendLogs
    from pydantic import Json
    import uuid

    if kwargs is None:
        kwargs = {}
    litellm_params = kwargs.get("litellm_params", {})
    # litellm_params["metadata"] may be explicitly None; coerce to {}
    metadata = litellm_params.get("metadata", {}) or {}
    messages = kwargs.get("messages")
    optional_params = kwargs.get("optional_params", {})
    call_type = kwargs.get("call_type", "litellm.completion")
    cache_hit = kwargs.get("cache_hit", False)
    # NOTE(review): assumes every response_obj carries a "usage" entry — confirm
    # this holds for all call types routed through this logger.
    usage = response_obj["usage"]
    request_id = response_obj.get("id", str(uuid.uuid4()))
    api_key = metadata.get("user_api_key", "")
    if api_key is not None and isinstance(api_key, str):
        # Never persist the raw key — store its SHA-256 hash instead.
        api_key = hash_token(api_key)

    payload = {
        "request_id": request_id,
        "call_type": call_type,
        "api_key": api_key,
        "cache_hit": cache_hit,
        "startTime": start_time,
        "endTime": end_time,
        "model": kwargs.get("model", ""),
        "user": kwargs.get("user", ""),
        "modelParameters": optional_params,
        "messages": messages,
        "response": response_obj,
        "usage": usage,
        "metadata": metadata,
    }

    # Derive which payload fields the schema declares as Json / str so the
    # coercion below stays in sync with LiteLLM_SpendLogs automatically.
    json_fields = [
        field
        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
        if field_type == Json or field_type == Optional[Json]
    ]
    str_fields = [
        field
        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
        if field_type == str or field_type == Optional[str]
    ]

    for param in json_fields:
        if param in payload and type(payload[param]) != Json:
            # BUG FIX: the previous if / if-elif-else chain serialized a
            # ModelResponse twice — model_dump_json() produced a string, which
            # then fell into the second chain's else and was json.dumps()'d
            # again, double-encoding the stored JSON. Serialize exactly once.
            if isinstance(
                payload[param],
                (litellm.ModelResponse, litellm.EmbeddingResponse, litellm.Usage),
            ):
                payload[param] = payload[param].model_dump_json()
            else:
                payload[param] = json.dumps(payload[param])

    for param in str_fields:
        if param in payload and not isinstance(payload[param], str):
            payload[param] = str(payload[param])

    return payload
4 changes: 4 additions & 0 deletions litellm/tests/test_key_generate_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def test_call_with_key_over_budget(custom_db_client):
# 5. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm._logging import verbose_proxy_logger
import logging

verbose_proxy_logger.setLevel(logging.DEBUG)
try:

async def test():
Expand Down
19 changes: 18 additions & 1 deletion schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
model LiteLLM_Config {
param_name String @id
param_value Json?
}
}

model LiteLLM_SpendLogs {
  request_id String @unique // one row per LLM call; upserts dedupe on this id
  api_key String @default ("") // SHA-256 hash of the caller's key, not the raw token
  call_type String // e.g. "litellm.completion"
  spend Float @default(0.0) // response cost for the call
  startTime DateTime // call start timestamp
  endTime DateTime // call end timestamp
  model String @default("")
  user String @default("")
  modelParameters Json @default("{}") // request optional_params, serialized to JSON
  messages Json @default("[]")
  response Json @default("{}")
  usage Json @default("{}") // token usage block from the model response
  metadata Json @default("{}")
  cache_hit String @default("") // NOTE(review): pydantic model defaults this to "False" — confirm intended default
}

0 comments on commit a262678

Please sign in to comment.