From fa45c569fd632ca9c25e9be1d9541a53430a2cd8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 15:43:38 -0800 Subject: [PATCH 1/6] feat: add cost tracking + caching for transcription calls --- litellm/caching.py | 86 ++++++++++++++++++++++--- litellm/llms/azure.py | 6 +- litellm/llms/openai.py | 6 +- litellm/proxy/proxy_server.py | 3 +- litellm/proxy/utils.py | 6 +- litellm/tests/test_completion_cost.py | 61 +++++++++++++++++- litellm/utils.py | 90 +++++++++++++++++++++------ tests/test_whisper.py | 4 +- 8 files changed, 225 insertions(+), 37 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index bf11f4c3979c..623b483e88e1 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -10,7 +10,7 @@ import litellm import time, logging, asyncio import json, traceback, ast, hashlib -from typing import Optional, Literal, List, Union, Any +from typing import Optional, Literal, List, Union, Any, BinaryIO from openai._models import BaseModel as OpenAIObject from litellm._logging import verbose_logger @@ -764,8 +764,24 @@ def __init__( password: Optional[str] = None, similarity_threshold: Optional[float] = None, supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], + List[ + Literal[ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ] + ] + ] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ], # s3 Bucket, boto3 configuration s3_bucket_name: Optional[str] = None, s3_region_name: Optional[str] = None, @@ -880,9 +896,18 @@ def get_cache_key(self, *args, **kwargs): "input", "encoding_format", ] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs - + transcription_only_kwargs = [ + "model", + "file", + "language", + "prompt", + "response_format", + "temperature", + ] # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set() - combined_kwargs = completion_kwargs + embedding_only_kwargs + combined_kwargs = ( + completion_kwargs + embedding_only_kwargs + transcription_only_kwargs + ) for param in combined_kwargs: # ignore litellm params here if param in kwargs: @@ -914,6 +939,17 @@ def get_cache_key(self, *args, **kwargs): param_value = ( caching_group or model_group or kwargs[param] ) # use caching_group, if set then model_group if it exists, else use kwargs["model"] + elif param == "file": + metadata_file_name = kwargs.get("metadata", {}).get( + "file_name", None + ) + litellm_params_file_name = kwargs.get("litellm_params", {}).get( + "file_name", None + ) + if metadata_file_name is not None: + param_value = metadata_file_name + elif litellm_params_file_name is not None: + param_value = litellm_params_file_name else: if kwargs[param] is None: continue # ignore None params @@ -1143,8 +1179,24 @@ def enable_cache( port: Optional[str] = None, password: Optional[str] = None, supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], + List[ + Literal[ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ] + ] + ] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ], **kwargs, ): """ @@ -1192,8 +1244,24 @@ def update_cache( port: Optional[str] = None, password: Optional[str] = None, supported_call_types: Optional[ - List[Literal["completion", "acompletion", "embedding", "aembedding"]] - ] = ["completion", "acompletion", "embedding", "aembedding"], + List[ + Literal[ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ] + ] + ] = [ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + ], **kwargs, ): """ diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 5fc0939bbc9a..0c8c7f184466 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -861,7 +861,8 @@ def audio_transcriptions( additional_args={"complete_input_dict": data}, original_response=stringified_response, ) - final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore + hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} + final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore return final_response async def async_audio_transcriptions( @@ -921,7 +922,8 @@ async def async_audio_transcriptions( }, original_response=stringified_response, ) - response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore + hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} + response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore return response except Exception as e: ## LOGGING diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 9850cd61eb53..3f4b59d4586b 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -824,7 +824,8 @@ def audio_transcriptions( additional_args={"complete_input_dict": data}, original_response=stringified_response, ) - final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore + hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"} + final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore return final_response async def async_audio_transcriptions( @@ -862,7 +863,8 @@ async def async_audio_transcriptions( additional_args={"complete_input_dict": data}, original_response=stringified_response, ) - return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore + hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"} + return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore except Exception as e: ## LOGGING logging_obj.post_call( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8b1db959c40d..720de8745d66 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3282,6 +3282,7 @@ async def audio_transcriptions( user_api_key_dict, "team_id", None ) data["metadata"]["endpoint"] = str(request.url) + data["metadata"]["file_name"] = file.filename ### TEAM-SPECIFIC PARAMS ### if user_api_key_dict.team_id is not None: @@ -3316,7 +3317,7 @@ async def audio_transcriptions( data = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, data=data, - call_type="moderation", + call_type="audio_transcription", ) ## ROUTE TO CORRECT ENDPOINT ## diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 89976ff0d8c5..270b53647b36 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -96,7 +96,11 @@ async def pre_call_hook( user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal[ - "completion", "embeddings", "image_generation", "moderation" + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", ], ): """ diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py index 947da71669b3..16ec0602d4ab 100644 --- a/litellm/tests/test_completion_cost.py +++ b/litellm/tests/test_completion_cost.py @@ -6,7 +6,12 @@ ) # Adds the parent directory to the system path import time import litellm -from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models +from litellm import ( + get_max_tokens, + model_cost, + open_ai_chat_completion_models, + TranscriptionResponse, +) import pytest @@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls(): messages=[{"role": "user", "content": "Hey, how's it going?"}], ) assert cost > 0 + + +def test_whisper_openai(): + litellm.set_verbose = True + transcription = TranscriptionResponse( + text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure." + ) + transcription._hidden_params = { + "model": "whisper-1", + "custom_llm_provider": "openai", + "optional_params": {}, + "model_id": None, + } + _total_time_in_seconds = 3 + + transcription._response_ms = _total_time_in_seconds * 1000 + cost = litellm.completion_cost(model="whisper-1", completion_response=transcription) + + print(f"cost: {cost}") + print(f"whisper dict: {litellm.model_cost['whisper-1']}") + expected_cost = round( + litellm.model_cost["whisper-1"]["output_cost_per_second"] + * _total_time_in_seconds, + 5, + ) + assert cost == expected_cost + + +def test_whisper_azure(): + litellm.set_verbose = True + transcription = TranscriptionResponse( + text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure." + ) + transcription._hidden_params = { + "model": "whisper-1", + "custom_llm_provider": "azure", + "optional_params": {}, + "model_id": None, + } + _total_time_in_seconds = 3 + + transcription._response_ms = _total_time_in_seconds * 1000 + cost = litellm.completion_cost( + model="azure/azure-whisper", completion_response=transcription + ) + + print(f"cost: {cost}") + print(f"whisper dict: {litellm.model_cost['whisper-1']}") + expected_cost = round( + litellm.model_cost["whisper-1"]["output_cost_per_second"] + * _total_time_in_seconds, + 5, + ) + assert cost == expected_cost diff --git a/litellm/utils.py b/litellm/utils.py index 7466bd5c6947..9568bf3f9dda 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1168,6 +1168,7 @@ def _success_handler_helper_fn( isinstance(result, ModelResponse) or isinstance(result, EmbeddingResponse) or isinstance(result, ImageResponse) + or isinstance(result, TranscriptionResponse) ) and self.stream != True ): # handle streaming separately @@ -1203,9 +1204,6 @@ def _success_handler_helper_fn( model=base_model, ) ) - verbose_logger.debug( - f"Model={self.model}; cost={self.model_call_details['response_cost']}" - ) except litellm.NotFoundError as e: verbose_logger.debug( f"Model={self.model} not found in completion cost map." @@ -1236,7 +1234,7 @@ def _success_handler_helper_fn( def success_handler( self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs ): - verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}") + print_verbose(f"Logging Details LiteLLM-Success Call: {cache_hit}") start_time, end_time, result = self._success_handler_helper_fn( start_time=start_time, end_time=end_time, @@ -1681,6 +1679,7 @@ async def async_success_handler( """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ + print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}") start_time, end_time, result = self._success_handler_helper_fn( start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit ) @@ -2473,6 +2472,7 @@ def wrapper(*args, **kwargs): and kwargs.get("aembedding", False) != True and kwargs.get("acompletion", False) != True and kwargs.get("aimg_generation", False) != True + and kwargs.get("atranscription", False) != True ): # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") @@ -2875,6 +2875,19 @@ async def wrapper_async(*args, **kwargs): model_response_object=EmbeddingResponse(), response_type="embedding", ) + elif call_type == CallTypes.atranscription.value and isinstance( + cached_result, dict + ): + hidden_params = { + "model": "whisper-1", + "custom_llm_provider": custom_llm_provider, + } + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=TranscriptionResponse(), + response_type="audio_transcription", + hidden_params=hidden_params, + ) if kwargs.get("stream", False) == False: # LOG SUCCESS asyncio.create_task( @@ -3001,6 +3014,20 @@ async def wrapper_async(*args, **kwargs): else: return result + # ADD HIDDEN PARAMS - additional call metadata + if hasattr(result, "_hidden_params"): + result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( + "id", None + ) + if ( + isinstance(result, ModelResponse) + or isinstance(result, EmbeddingResponse) + or isinstance(result, TranscriptionResponse) + ): + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model) @@ -3013,8 +3040,10 @@ async def wrapper_async(*args, **kwargs): ) and (kwargs.get("cache", {}).get("no-store", False) != True) ): - if isinstance(result, litellm.ModelResponse) or isinstance( - result, litellm.EmbeddingResponse + if ( + isinstance(result, litellm.ModelResponse) + or isinstance(result, litellm.EmbeddingResponse) + or isinstance(result, TranscriptionResponse) ): if ( isinstance(result, EmbeddingResponse) @@ -3058,18 +3087,7 @@ async def wrapper_async(*args, **kwargs): args=(result, start_time, end_time), ).start() - # RETURN RESULT - if hasattr(result, "_hidden_params"): - result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( - "id", None - ) - if isinstance(result, ModelResponse) or isinstance( - result, EmbeddingResponse - ): - result._response_ms = ( - end_time - start_time - ).total_seconds() * 1000 # return response latency in ms like openai - + # REBUILD EMBEDDING CACHING if ( isinstance(result, EmbeddingResponse) and final_embedding_cached_response is not None @@ -3575,6 +3593,20 @@ def cost_per_token( completion_tokens_cost_usd_dollar = ( model_cost_ref[model]["output_cost_per_token"] * completion_tokens ) + elif ( + model_cost_ref[model].get("output_cost_per_second", None) is not None + and response_time_ms is not None + ): + print_verbose( + f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}" + ) + ## COST PER SECOND ## + prompt_tokens_cost_usd_dollar = 0 + completion_tokens_cost_usd_dollar = ( + model_cost_ref[model]["output_cost_per_second"] + * response_time_ms + / 1000 + ) elif ( model_cost_ref[model].get("input_cost_per_second", None) is not None and response_time_ms is not None @@ -3659,6 +3691,8 @@ def completion_cost( "text_completion", "image_generation", "aimage_generation", + "transcription", + "atranscription", ] = "completion", ### REGION ### custom_llm_provider=None, @@ -3703,6 +3737,7 @@ def completion_cost( and custom_llm_provider == "azure" ): model = "dall-e-2" # for dall-e-2, azure expects an empty model name + # Handle Inputs to completion_cost prompt_tokens = 0 completion_tokens = 0 @@ -3717,10 +3752,11 @@ def completion_cost( verbose_logger.debug( f"completion_response response ms: {completion_response.get('_response_ms')} " ) - model = ( - model or completion_response["model"] + model = model or completion_response.get( + "model", None ) # check if user passed an override for model, if it's none check completion_response['model'] if hasattr(completion_response, "_hidden_params"): + model = completion_response._hidden_params.get("model", model) custom_llm_provider = completion_response._hidden_params.get( "custom_llm_provider", "" ) @@ -3801,6 +3837,7 @@ def completion_cost( # see https://replicate.com/pricing elif model in litellm.replicate_models or "replicate" in model: return get_replicate_completion_pricing(completion_response, total_time) + ( prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar, @@ -6314,6 +6351,7 @@ def convert_to_model_response_object( stream=False, start_time=None, end_time=None, + hidden_params: Optional[dict] = None, ): try: if response_type == "completion" and ( @@ -6373,6 +6411,9 @@ def convert_to_model_response_object( end_time - start_time ).total_seconds() * 1000 + if hidden_params is not None: + model_response_object._hidden_params = hidden_params + return model_response_object elif response_type == "embedding" and ( model_response_object is None @@ -6402,6 +6443,9 @@ def convert_to_model_response_object( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai + if hidden_params is not None: + model_response_object._hidden_params = hidden_params + return model_response_object elif response_type == "image_generation" and ( model_response_object is None @@ -6419,6 +6463,9 @@ def convert_to_model_response_object( if "data" in response_object: model_response_object.data = response_object["data"] + if hidden_params is not None: + model_response_object._hidden_params = hidden_params + return model_response_object elif response_type == "audio_transcription" and ( model_response_object is None @@ -6432,6 +6479,9 @@ def convert_to_model_response_object( if "text" in response_object: model_response_object.text = response_object["text"] + + if hidden_params is not None: + model_response_object._hidden_params = hidden_params return model_response_object except Exception as e: raise Exception(f"Invalid response object {traceback.format_exc()}") diff --git a/tests/test_whisper.py b/tests/test_whisper.py index 54ecfbf50c37..1debbbc1db38 100644 --- a/tests/test_whisper.py +++ b/tests/test_whisper.py @@ -31,7 +31,8 @@ def test_transcription(): model="whisper-1", file=audio_file, ) - print(f"transcript: {transcript}") + print(f"transcript: {transcript.model_dump()}") + print(f"transcript: {transcript._hidden_params}") # test_transcription() @@ -47,6 +48,7 @@ def test_transcription_azure(): api_version="2024-02-15-preview", ) + print(f"transcript: {transcript}") assert transcript.text is not None assert isinstance(transcript.text, str) From 9b7c8880c09edde866372c5e0910bdba2c0afef7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 16:09:12 -0800 Subject: [PATCH 2/6] fix(caching.py): only add unique kwargs for transcription_only_kwargs in caching --- litellm/caching.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 623b483e88e1..d1add41fcf5b 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -897,12 +897,8 @@ def get_cache_key(self, *args, **kwargs): "encoding_format", ] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs transcription_only_kwargs = [ - "model", "file", "language", - "prompt", - "response_format", - "temperature", ] # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set() combined_kwargs = ( From c333216f6ec6f75101f817373bdb0c553487d96c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 16:20:14 -0800 Subject: [PATCH 3/6] test(test_custom_logger.py): make test more verbose --- litellm/tests/test_custom_logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index fe1307689066..0a8f7b941642 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -100,7 +100,7 @@ async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_ti def test_async_chat_openai_stream(): try: tmp_function = TmpFunction() - # litellm.set_verbose = True + litellm.set_verbose = True litellm.success_callback = [tmp_function.async_test_logging_fn] complete_streaming_response = "" From 8d2d51b625cd71f9d44c8ea50ee89ca3a1292176 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 18:22:26 -0800 Subject: [PATCH 4/6] fix(utils.py): fix model name checking --- litellm/llms/openai.py | 1 + litellm/tests/test_custom_callback_input.py | 1 + litellm/utils.py | 24 ++++++++++++++++----- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 3f4b59d4586b..f65d96b11399 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -753,6 +753,7 @@ def image_generation( # return response return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore except OpenAIError as e: + exception_mapping_worked = True ## LOGGING logging_obj.post_call( diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 9249333197b4..5c52867f9394 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -973,6 +973,7 @@ def test_image_generation_openai(): print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.states: {customHandler_success.states}") + time.sleep(2) assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback diff --git a/litellm/utils.py b/litellm/utils.py index 9568bf3f9dda..ffd82d53d191 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1243,7 +1243,7 @@ def success_handler( ) # print(f"original response in success handler: {self.model_call_details['original_response']}") try: - verbose_logger.debug(f"success callbacks: {litellm.success_callback}") + print_verbose(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None if self.stream and isinstance(result, ModelResponse): @@ -1266,7 +1266,7 @@ def success_handler( self.sync_streaming_chunks.append(result) if complete_streaming_response is not None: - verbose_logger.debug( + print_verbose( f"Logging Details LiteLLM-Success Call streaming complete" ) self.model_call_details["complete_streaming_response"] = ( @@ -1613,6 +1613,14 @@ def success_handler( "aembedding", False ) == False + and self.model_call_details.get("litellm_params", {}).get( + "aimage_generation", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "atranscription", False + ) + == False ): # custom logger class if self.stream and complete_streaming_response is None: callback.log_stream_event( @@ -1645,6 +1653,14 @@ def success_handler( "aembedding", False ) == False + and self.model_call_details.get("litellm_params", {}).get( + "aimage_generation", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "atranscription", False + ) + == False ): # custom logger functions print_verbose( f"success callbacks: Running Custom Callback Function" @@ -3728,7 +3744,6 @@ def completion_cost( - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path. """ try: - if ( (call_type == "aimage_generation" or call_type == "image_generation") and model is not None @@ -3737,7 +3752,6 @@ def completion_cost( and custom_llm_provider == "azure" ): model = "dall-e-2" # for dall-e-2, azure expects an empty model name - # Handle Inputs to completion_cost prompt_tokens = 0 completion_tokens = 0 @@ -3756,7 +3770,7 @@ def completion_cost( "model", None ) # check if user passed an override for model, if it's none check completion_response['model'] if hasattr(completion_response, "_hidden_params"): - model = completion_response._hidden_params.get("model", model) + model = model or completion_response._hidden_params.get("model", None) custom_llm_provider = completion_response._hidden_params.get( "custom_llm_provider", "" ) From 61d16cb67202efe88fdbdd410ecce66a60f3cae3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 18:47:20 -0800 Subject: [PATCH 5/6] fix(test_proxy_server.py): fix test --- litellm/tests/test_proxy_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index d5e8f09c680f..3d839b26cdc5 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -336,6 +336,8 @@ def test_load_router_config(): "acompletion", "embedding", "aembedding", + "atranscription", + "transcription", ] # init with all call types litellm.disable_cache() From 1d15dde6de2c9a12a0e94911a575f3763212a025 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 9 Mar 2024 19:11:37 -0800 Subject: [PATCH 6/6] fix(utils.py): fix model setting in completion cost --- litellm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index ffd82d53d191..b19bf62fa76d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3770,7 +3770,11 @@ def completion_cost( "model", None ) # check if user passed an override for model, if it's none check completion_response['model'] if hasattr(completion_response, "_hidden_params"): - model = model or completion_response._hidden_params.get("model", None) + if ( + completion_response._hidden_params.get("model", None) is not None + and len(completion_response._hidden_params["model"]) > 0 + ): + model = completion_response._hidden_params.get("model", model) custom_llm_provider = completion_response._hidden_params.get( "custom_llm_provider", "" )