feat: add cost tracking + caching for /audio/transcription calls #2426

Merged · 6 commits · Mar 10, 2024
82 changes: 73 additions & 9 deletions litellm/caching.py
@@ -10,7 +10,7 @@
 import litellm
 import time, logging, asyncio
 import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any
+from typing import Optional, Literal, List, Union, Any, BinaryIO
 from openai._models import BaseModel as OpenAIObject
 from litellm._logging import verbose_logger

@@ -764,8 +764,24 @@ def __init__(
         password: Optional[str] = None,
         similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
-            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-        ] = ["completion", "acompletion", "embedding", "aembedding"],
+            List[
+                Literal[
+                    "completion",
+                    "acompletion",
+                    "embedding",
+                    "aembedding",
+                    "atranscription",
+                    "transcription",
+                ]
+            ]
+        ] = [
+            "completion",
+            "acompletion",
+            "embedding",
+            "aembedding",
+            "atranscription",
+            "transcription",
+        ],
         # s3 Bucket, boto3 configuration
         s3_bucket_name: Optional[str] = None,
         s3_region_name: Optional[str] = None,
@@ -880,9 +896,14 @@ def get_cache_key(self, *args, **kwargs):
             "input",
             "encoding_format",
         ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
-
+        transcription_only_kwargs = [
+            "file",
+            "language",
+        ]
         # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = completion_kwargs + embedding_only_kwargs
+        combined_kwargs = (
+            completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
+        )
         for param in combined_kwargs:
             # ignore litellm params here
             if param in kwargs:
@@ -914,6 +935,17 @@ def get_cache_key(self, *args, **kwargs):
                     param_value = (
                         caching_group or model_group or kwargs[param]
                     )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
+                elif param == "file":
+                    metadata_file_name = kwargs.get("metadata", {}).get(
+                        "file_name", None
+                    )
+                    litellm_params_file_name = kwargs.get("litellm_params", {}).get(
+                        "file_name", None
+                    )
+                    if metadata_file_name is not None:
+                        param_value = metadata_file_name
+                    elif litellm_params_file_name is not None:
+                        param_value = litellm_params_file_name
                 else:
                     if kwargs[param] is None:
                         continue  # ignore None params
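
Note: the hunk above means transcription requests are keyed on a caller-supplied file name rather than the file object itself, which is not stably hashable. A rough self-contained sketch of that behavior — the helper below is hypothetical, not litellm's actual key builder:

import hashlib

def transcription_cache_key(kwargs: dict) -> str:
    # Prefer the file name stashed in metadata (the proxy sets this per
    # request), then litellm_params, mirroring the elif branch above.
    file_name = kwargs.get("metadata", {}).get("file_name") or kwargs.get(
        "litellm_params", {}
    ).get("file_name")
    parts = [f"model: {kwargs.get('model')}"]
    if file_name is not None:
        parts.append(f"file: {file_name}")
    if kwargs.get("language") is not None:
        parts.append(f"language: {kwargs['language']}")
    # Hash the ordered param string so keys stay bounded in length.
    return hashlib.sha256(",".join(parts).encode()).hexdigest()

# Two requests with the same model + file name map to the same key:
print(transcription_cache_key({"model": "whisper-1", "metadata": {"file_name": "meeting.mp3"}}))
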
@@ -1143,8 +1175,24 @@ def enable_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
@@ -1192,8 +1240,24 @@ def update_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
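
Note: with "transcription"/"atranscription" in supported_call_types, callers can opt audio calls into the cache. A minimal usage sketch (local cache backend assumed; "speech.mp3" is a placeholder file):

import litellm

litellm.cache = litellm.Cache(
    type="local",
    supported_call_types=["completion", "transcription", "atranscription"],
)

with open("speech.mp3", "rb") as audio_file:
    # First call hits the provider; a repeat with the same file name and
    # params can be served from the cache.
    response = litellm.transcription(model="whisper-1", file=audio_file)
    print(response.text)
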
6 changes: 4 additions & 2 deletions litellm/llms/azure.py
@@ -861,7 +861,8 @@ def audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -921,7 +922,8 @@ async def async_audio_transcriptions(
                 },
                 original_response=stringified_response,
             )
-            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return response
         except Exception as e:
             ## LOGGING
7 changes: 5 additions & 2 deletions litellm/llms/openai.py
@@ -753,6 +753,7 @@ def image_generation(
             # return response
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
+
             exception_mapping_worked = True
             ## LOGGING
             logging_obj.post_call(
@@ -824,7 +825,8 @@ def audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -862,7 +864,8 @@ async def async_audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
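
Note: the hidden_params added in both provider handlers are what let completion_cost() price a transcription by duration. A sketch of the downstream flow, mirroring the new tests below (the private attribute names come from those tests):

import litellm
from litellm import TranscriptionResponse

# Build a response the way the handlers above do: hidden params identify the
# model/provider, and _response_ms carries the measured duration.
transcription = TranscriptionResponse(text="hello world")
transcription._hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
transcription._response_ms = 3000  # 3 seconds of billed audio time

cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
print(cost)  # output_cost_per_second * 3, per litellm.model_cost["whisper-1"]
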
3 changes: 2 additions & 1 deletion litellm/proxy/proxy_server.py
@@ -3282,6 +3282,7 @@ async def audio_transcriptions(
                 user_api_key_dict, "team_id", None
             )
             data["metadata"]["endpoint"] = str(request.url)
+            data["metadata"]["file_name"] = file.filename

             ### TEAM-SPECIFIC PARAMS ###
             if user_api_key_dict.team_id is not None:
@@ -3316,7 +3317,7 @@ async def audio_transcriptions(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict,
             data=data,
-            call_type="moderation",
+            call_type="audio_transcription",
         )

         ## ROUTE TO CORRECT ENDPOINT ##
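
Note: fixing call_type from "moderation" to "audio_transcription" means proxy hooks can now distinguish transcription traffic. A sketch of a custom hook keyed on it (the class/method shape follows litellm's CustomLogger pattern; treat the exact signature as indicative, not authoritative):

from litellm.integrations.custom_logger import CustomLogger

class AudioGuard(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        # Runs before routing; data carries the metadata set in proxy_server.py.
        if call_type == "audio_transcription":
            file_name = data.get("metadata", {}).get("file_name", "<unknown>")
            print(f"transcription request for {file_name}")
            # e.g. enforce per-team limits or rewrite params here
        return data
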
6 changes: 5 additions & 1 deletion litellm/proxy/utils.py
@@ -96,7 +96,11 @@ async def pre_call_hook(
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
-            "completion", "embeddings", "image_generation", "moderation"
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
         ],
     ):
         """
61 changes: 60 additions & 1 deletion litellm/tests/test_completion_cost.py
@@ -6,7 +6,12 @@
 )  # Adds the parent directory to the system path
 import time
 import litellm
-from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
+from litellm import (
+    get_max_tokens,
+    model_cost,
+    open_ai_chat_completion_models,
+    TranscriptionResponse,
+)
 import pytest


@@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls():
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
     )
     assert cost > 0
+
+
+def test_whisper_openai():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "openai",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
+
+
+def test_whisper_azure():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "azure",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(
+        model="azure/azure-whisper", completion_response=transcription
+    )
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
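
Note: both tests reduce to billed seconds times the per-second output rate. Assuming whisper-1's published price of $0.006 per minute ($0.0001 per second), the 3-second fixture works out to:

expected_cost = round(0.0001 * 3, 5)  # rate is an assumption; read it from litellm.model_cost in practice
print(expected_cost)  # 0.0003
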
1 change: 1 addition & 0 deletions litellm/tests/test_custom_callback_input.py
@@ -973,6 +973,7 @@ def test_image_generation_openai():

         print(f"customHandler_success.errors: {customHandler_success.errors}")
         print(f"customHandler_success.states: {customHandler_success.states}")
+        time.sleep(2)
         assert len(customHandler_success.errors) == 0
         assert len(customHandler_success.states) == 3  # pre, post, success
         # test failure callback
2 changes: 1 addition & 1 deletion litellm/tests/test_custom_logger.py
@@ -100,7 +100,7 @@ async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
 def test_async_chat_openai_stream():
     try:
         tmp_function = TmpFunction()
-        # litellm.set_verbose = True
+        litellm.set_verbose = True
         litellm.success_callback = [tmp_function.async_test_logging_fn]
         complete_streaming_response = ""

2 changes: 2 additions & 0 deletions litellm/tests/test_proxy_server.py
@@ -336,6 +336,8 @@ def test_load_router_config():
             "acompletion",
             "embedding",
             "aembedding",
+            "atranscription",
+            "transcription",
         ]  # init with all call types

         litellm.disable_cache()