diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index ad90173a44e6..e01f6f3a3316 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1,8 +1,8 @@
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator, Json
 import enum
-from typing import Optional, List, Union, Dict, Literal
+from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
-import uuid, json
+import uuid, json, sys, os
 
 
 class LiteLLMBase(BaseModel):
@@ -196,6 +196,7 @@ class DynamoDBArgs(LiteLLMBase):
     user_table_name: str = "LiteLLM_UserTable"
     key_table_name: str = "LiteLLM_VerificationToken"
     config_table_name: str = "LiteLLM_Config"
+    spend_table_name: str = "LiteLLM_SpendLogs"
 
 
 class ConfigGeneralSettings(LiteLLMBase):
@@ -314,3 +315,20 @@ def set_model_info(cls, values):
         if values.get("models") is None:
             values.update({"models", []})
         return values
+
+
+class LiteLLM_SpendLogs(LiteLLMBase):
+    request_id: str
+    api_key: str
+    model: Optional[str] = ""
+    call_type: str
+    spend: Optional[float] = 0.0
+    startTime: Union[str, datetime, None]
+    endTime: Union[str, datetime, None]
+    user: Optional[str] = ""
+    modelParameters: Optional[Json] = {}
+    messages: Optional[Json] = []
+    response: Optional[Json] = {}
+    usage: Optional[Json] = {}
+    metadata: Optional[Json] = {}
+    cache_hit: Optional[str] = "False"
diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py
index eb1c0852861d..83cf6b157246 100644
--- a/litellm/proxy/db/dynamo_db.py
+++ b/litellm/proxy/db/dynamo_db.py
@@ -131,10 +131,27 @@ async def connect(self):
             raise Exception(
                 f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
             )
+
+        ## Spend
+        try:
+            verbose_proxy_logger.debug("DynamoDB Wrapper - Creating Spend Table")
+            error_occurred = False
+            table = client.table(self.database_arguments.spend_table_name)
+            if not await table.exists():
+                await table.create(
+                    self.throughput_type,
+                    KeySchema(hash_key=KeySpec("request_id", KeyType.string)),
+                )
+        except Exception as e:
+            error_occurred = True
+        if error_occurred == True:
+            raise Exception(
+                f"Failed to create table - {self.database_arguments.spend_table_name}.\nPlease create a new table called {self.database_arguments.spend_table_name}\nAND set `hash_key` as 'request_id'"
+            )
         verbose_proxy_logger.debug("DynamoDB Wrapper - Done connecting()")
 
     async def insert_data(
-        self, value: Any, table_name: Literal["user", "key", "config"]
+        self, value: Any, table_name: Literal["user", "key", "config", "spend"]
     ):
         from aiodynamo.client import Client
         from aiodynamo.credentials import Credentials, StaticCredentials
@@ -166,6 +183,8 @@ async def insert_data(
             table = client.table(self.database_arguments.key_table_name)
         elif table_name == "config":
             table = client.table(self.database_arguments.config_table_name)
+        elif table_name == "spend":
+            table = client.table(self.database_arguments.spend_table_name)
 
         for k, v in value.items():
             if isinstance(v, datetime):
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 5b87ab775b2a..8cd2fcec85fe 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -61,8 +61,8 @@ litellm_settings:
 # setting callback class
 # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
 
-# general_settings:
-  # master_key: sk-1234
+general_settings:
+  master_key: sk-1234
 # database_type: "dynamo_db"
 # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
 #   "billing_mode": "PAY_PER_REQUEST",
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 10c968b1c5e2..afa0f0fe007a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -72,6 +72,7 @@ def generate_feedback_box():
     ProxyLogging,
     _cache_user_row,
     send_email,
+    get_logging_payload,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 import pydantic
@@ -518,6 +519,7 @@ async def track_cost_callback(
     global prisma_client, custom_db_client
     try:
         # check if it has collected an entire stream response
+        verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
@@ -538,7 +540,13 @@ async def track_cost_callback(
                 prisma_client is not None or custom_db_client is not None
             ):
                 await update_database(
-                    token=user_api_key, response_cost=response_cost, user_id=user_id
+                    token=user_api_key,
+                    response_cost=response_cost,
+                    user_id=user_id,
+                    kwargs=kwargs,
+                    completion_response=completion_response,
+                    start_time=start_time,
+                    end_time=end_time,
                 )
         elif kwargs["stream"] == False:  # for non streaming responses
             response_cost = litellm.completion_cost(
@@ -554,13 +562,27 @@ async def track_cost_callback(
                 prisma_client is not None or custom_db_client is not None
             ):
                 await update_database(
-                    token=user_api_key, response_cost=response_cost, user_id=user_id
+                    token=user_api_key,
+                    response_cost=response_cost,
+                    user_id=user_id,
+                    kwargs=kwargs,
+                    completion_response=completion_response,
+                    start_time=start_time,
+                    end_time=end_time,
                 )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
 
 
-async def update_database(token, response_cost, user_id=None):
+async def update_database(
+    token,
+    response_cost,
+    user_id=None,
+    kwargs=None,
+    completion_response=None,
+    start_time=None,
+    end_time=None,
+):
     try:
         verbose_proxy_logger.debug(
             f"Enters prisma db call, token: {token}; user_id: {user_id}"
@@ -630,9 +652,28 @@ async def _update_key_db():
                     key=token, value={"spend": new_spend}, table_name="key"
                 )
 
+        async def _insert_spend_log_to_db():
+            # Helper to generate payload to log
+            verbose_proxy_logger.debug("inserting spend log to db")
+            payload = get_logging_payload(
+                kwargs=kwargs,
+                response_obj=completion_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
+            payload["spend"] = response_cost
+
+            if prisma_client is not None:
+                await prisma_client.insert_data(data=payload, table_name="spend")
+
+            elif custom_db_client is not None:
+                await custom_db_client.insert_data(payload, table_name="spend")
+
         tasks = []
         tasks.append(_update_user_db())
         tasks.append(_update_key_db())
+        tasks.append(_insert_spend_log_to_db())
         await asyncio.gather(*tasks)
     except Exception as e:
         verbose_proxy_logger.debug(
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index aa45a8818658..2e40a32045dd 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
 model LiteLLM_Config {
   param_name String @id
   param_value Json?
-}
\ No newline at end of file
+}
+
+model LiteLLM_SpendLogs {
+  request_id      String   @unique
+  call_type       String
+  api_key         String   @default("")
+  spend           Float    @default(0.0)
+  startTime       DateTime // Assuming start_time is a DateTime field
+  endTime         DateTime // Assuming end_time is a DateTime field
+  model           String   @default("")
+  user            String   @default("")
+  modelParameters Json     @default("{}") // Assuming optional_params is a JSON field
+  messages        Json     @default("[]")
+  response        Json     @default("{}")
+  usage           Json     @default("{}")
+  metadata        Json     @default("{}")
+  cache_hit       String   @default("")
+}
\ No newline at end of file
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index ab1fea463eef..23b66f22d796 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,7 +1,12 @@
 from typing import Optional, List, Any, Literal, Union
 import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
 import litellm, backoff
-from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken
+from litellm.proxy._types import (
+    UserAPIKeyAuth,
+    DynamoDBArgs,
+    LiteLLM_VerificationToken,
+    LiteLLM_SpendLogs,
+)
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
@@ -316,7 +321,7 @@ async def get_generic_data(
         self,
         key: str,
         value: Any,
-        table_name: Literal["users", "keys", "config"],
+        table_name: Literal["users", "keys", "config", "spend"],
     ):
         """
         Generic implementation of get data
@@ -334,6 +339,10 @@ async def get_generic_data(
                 response = await self.db.litellm_config.find_first(  # type: ignore
                     where={key: value}  # type: ignore
                 )
+            elif table_name == "spend":
+                response = await self.db.litellm_spendlogs.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
             return response
         except Exception as e:
             asyncio.create_task(
@@ -417,7 +426,7 @@ async def get_data(
         on_backoff=on_backoff,  # specifying the function to call on backoff
     )
     async def insert_data(
-        self, data: dict, table_name: Literal["user", "key", "config"]
+        self, data: dict, table_name: Literal["user", "key", "config", "spend"]
     ):
         """
         Add a key to the database. If it already exists, do nothing.
@@ -473,8 +482,18 @@ async def insert_data(
                     )
                     tasks.append(updated_table_row)
 
-            await asyncio.gather(*tasks)
+            elif table_name == "spend":
+                db_data = self.jsonify_object(data=data)
+                new_spend_row = await self.db.litellm_spendlogs.upsert(
+                    where={"request_id": data["request_id"]},
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                return new_spend_row
+
 
         except Exception as e:
             print_verbose(f"LiteLLM Prisma Client Exception: {e}")
             asyncio.create_task(
@@ -760,3 +779,85 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):
 
     except Exception as e:
         print_verbose("An error occurred while sending the email:", str(e))
+
+
+def hash_token(token: str):
+    import hashlib
+
+    # Hash the string using SHA-256
+    hashed_token = hashlib.sha256(token.encode()).hexdigest()
+
+    return hashed_token
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time):
+    from litellm.proxy._types import LiteLLM_SpendLogs
+    from pydantic import Json
+    import uuid
+
+    if kwargs == None:
+        kwargs = {}
+    # standardize this function to be used across, s3, dynamoDB, langfuse logging
+    litellm_params = kwargs.get("litellm_params", {})
+    metadata = (
+        litellm_params.get("metadata", {}) or {}
+    )  # if litellm_params['metadata'] == None
+    messages = kwargs.get("messages")
+    optional_params = kwargs.get("optional_params", {})
+    call_type = kwargs.get("call_type", "litellm.completion")
+    cache_hit = kwargs.get("cache_hit", False)
+    usage = response_obj["usage"]
+    id = response_obj.get("id", str(uuid.uuid4()))
+    api_key = metadata.get("user_api_key", "")
+    if api_key is not None and type(api_key) == str:
+        # hash the api_key
+        api_key = hash_token(api_key)
+
+    payload = {
+        "request_id": id,
+        "call_type": call_type,
+        "api_key": api_key,
+        "cache_hit": cache_hit,
+        "startTime": start_time,
+        "endTime": end_time,
+        "model": kwargs.get("model", ""),
+        "user": kwargs.get("user", ""),
+        "modelParameters": optional_params,
+        "messages": messages,
+        "response": response_obj,
+        "usage": usage,
+        "metadata": metadata,
+    }
+
+    json_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == Json or field_type == Optional[Json]
+    ]
+    str_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == str or field_type == Optional[str]
+    ]
+    datetime_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == datetime
+    ]
+
+    for param in json_fields:
+        if param in payload and type(payload[param]) != Json:
+            if type(payload[param]) == litellm.ModelResponse:
+                payload[param] = payload[param].model_dump_json()
+            if type(payload[param]) == litellm.EmbeddingResponse:
+                payload[param] = payload[param].model_dump_json()
+            elif type(payload[param]) == litellm.Usage:
+                payload[param] = payload[param].model_dump_json()
+            else:
+                payload[param] = json.dumps(payload[param])
+
+    for param in str_fields:
+        if param in payload and type(payload[param]) != str:
+            payload[param] = str(payload[param])
+
+    return payload
diff --git a/litellm/tests/test_key_generate_dynamodb.py b/litellm/tests/test_key_generate_dynamodb.py
index 09f699af7d7e..2cfa9c95312c 100644
--- a/litellm/tests/test_key_generate_dynamodb.py
+++ b/litellm/tests/test_key_generate_dynamodb.py
@@ -179,6 +179,10 @@ def test_call_with_key_over_budget(custom_db_client):
     # 5. Make a call with a key over budget, expect to fail
     setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    from litellm._logging import verbose_proxy_logger
+    import logging
+
+    verbose_proxy_logger.setLevel(logging.DEBUG)
     try:
 
         async def test():
diff --git a/schema.prisma b/schema.prisma
index 704ada42c980..31eae05c2ee8 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
 model LiteLLM_Config {
   param_name String @id
   param_value Json?
-}
\ No newline at end of file
+}
+
+model LiteLLM_SpendLogs {
+  request_id      String   @unique
+  api_key         String   @default("")
+  call_type       String
+  spend           Float    @default(0.0)
+  startTime       DateTime // Assuming start_time is a DateTime field
+  endTime         DateTime // Assuming end_time is a DateTime field
+  model           String   @default("")
+  user            String   @default("")
+  modelParameters Json     @default("{}") // Assuming optional_params is a JSON field
+  messages        Json     @default("[]")
+  response        Json     @default("{}")
+  usage           Json     @default("{}")
+  metadata        Json     @default("{}")
+  cache_hit       String   @default("")
+}