From 9c229ae3d15d8cc21ccc58107a91a1b01d6927c8 Mon Sep 17 00:00:00 2001 From: coconut49 Date: Wed, 18 Oct 2023 14:31:43 +0800 Subject: [PATCH] Refactor code for better readability and remove unnecessary comments in Dockerfile. --- Dockerfile | 5 +- litellm/proxy/proxy_server.py | 255 ++++++++++++++++++++++------------ 2 files changed, 171 insertions(+), 89 deletions(-) diff --git a/Dockerfile b/Dockerfile index cd4f86da414f..42b223b1fc24 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,4 @@ RUN pip install -r requirements.txt WORKDIR /app/litellm/proxy EXPOSE 8000 -ENTRYPOINT [ "python3", "proxy_cli.py" ] -# TODO - Set up a GitHub Action to automatically create the Docker image, -# and then we can quickly deploy the litellm proxy in the following way -# `docker run -p 8000:8000 -v ./secrets_template.toml:/root/.config/litellm/litellm.secrets.toml ghcr.io/BerriAI/litellm:v0.8.4` \ No newline at end of file +ENTRYPOINT [ "python3", "proxy_cli.py" ] \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c4ec989df32f..cf0a574f26a4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -19,7 +19,19 @@ import sys subprocess.check_call( - [sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"]) + [ + sys.executable, + "-m", + "pip", + "install", + "uvicorn", + "fastapi", + "tomli", + "appdirs", + "tomli-w", + "backoff", + ] + ) import uvicorn import fastapi import tomli as tomllib @@ -52,14 +64,17 @@ def generate_feedback_box(): message = random.choice(list_of_messages) print() - print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m') - print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m') - print('\033[1;37m' + '# {:^59} #\033[0m'.format(message)) - print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new')) - print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m') - print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m') + print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") + print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") + print("\033[1;37m" + "# {:^59} #\033[0m".format(message)) + print( + "\033[1;37m" + + "# {:^59} #\033[0m".format("https://github.com/BerriAI/litellm/issues/new") + ) + print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") + print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") print() - print(' Thank you for using LiteLLM! - Krrish & Ishaan') + print(" Thank you for using LiteLLM! - Krrish & Ishaan") print() print() @@ -67,7 +82,9 @@ def generate_feedback_box(): generate_feedback_box() print() -print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m") +print( + "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" +) print() print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m") print() @@ -104,8 +121,10 @@ def generate_feedback_box(): config_filename = "litellm.secrets.toml" config_dir = os.getcwd() config_dir = appdirs.user_config_dir("litellm") -user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) -log_file = 'api_log.json' +user_config_path = os.getenv( + "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename) +) +log_file = "api_log.json" #### HELPER FUNCTIONS #### @@ -123,12 +142,13 @@ def find_avatar_url(role): def usage_telemetry( - feature: str): # helps us know if people are using this feature. 
Set `litellm --telemetry False` to your cli call to turn this off + feature: str, +): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off if user_telemetry: - data = { - "feature": feature # "local_proxy_server" - } - threading.Thread(target=litellm.utils.litellm_telemetry, args=(data,), daemon=True).start() + data = {"feature": feature} # "local_proxy_server" + threading.Thread( + target=litellm.utils.litellm_telemetry, args=(data,), daemon=True + ).start() def add_keys_to_config(key, value): @@ -141,11 +161,11 @@ def add_keys_to_config(key, value): # File doesn't exist, create empty config config = {} - # Add new key - config.setdefault('keys', {})[key] = value + # Add new key + config.setdefault("keys", {})[key] = value - # Write config to file - with open(user_config_path, 'wb') as f: + # Write config to file + with open(user_config_path, "wb") as f: tomli_w.dump(config, f) @@ -159,15 +179,15 @@ def save_params_to_config(data: dict): # File doesn't exist, create empty config config = {} - config.setdefault('general', {}) + config.setdefault("general", {}) - ## general config + ## general config general_settings = data["general"] for key, value in general_settings.items(): config["general"][key] = value - ## model-specific config + ## model-specific config config.setdefault("model", {}) config["model"].setdefault(user_model, {}) @@ -177,8 +197,8 @@ def save_params_to_config(data: dict): for key, value in user_model_config.items(): config["model"][model_key][key] = value - # Write config to file - with open(user_config_path, 'wb') as f: + # Write config to file + with open(user_config_path, "wb") as f: tomli_w.dump(config, f) @@ -192,16 +212,23 @@ def load_config(): ## load keys if "keys" in user_config: for key in user_config["keys"]: - os.environ[key] = user_config["keys"][key] # litellm can read keys from the environment + os.environ[key] = user_config["keys"][ + key + ] # litellm can read keys from the environment ## settings if "general" in user_config: - litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt", - True) # by default add function to prompt if unsupported by provider - litellm.drop_params = user_config["general"].get("drop_params", - True) # by default drop params if unsupported by provider - litellm.model_fallbacks = user_config["general"].get("fallbacks", - None) # fallback models in case initial completion call fails - default_model = user_config["general"].get("default_model", None) # route all requests to this model. + litellm.add_function_to_prompt = user_config["general"].get( + "add_function_to_prompt", True + ) # by default add function to prompt if unsupported by provider + litellm.drop_params = user_config["general"].get( + "drop_params", True + ) # by default drop params if unsupported by provider + litellm.model_fallbacks = user_config["general"].get( + "fallbacks", None + ) # fallback models in case initial completion call fails + default_model = user_config["general"].get( + "default_model", None + ) # route all requests to this model. if user_model is None: # `litellm --model `` > default_model. 
user_model = default_model @@ -225,32 +252,63 @@ def load_config(): ## custom prompt template if "prompt_template" in model_config: model_prompt_template = model_config["prompt_template"] - if len(model_prompt_template.keys()) > 0: # if user has initialized this at all + if ( + len(model_prompt_template.keys()) > 0 + ): # if user has initialized this at all litellm.register_prompt_template( model=user_model, - initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""), + initial_prompt_value=model_prompt_template.get( + "MODEL_PRE_PROMPT", "" + ), roles={ "system": { - "pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""), - "post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""), + "pre_message": model_prompt_template.get( + "MODEL_SYSTEM_MESSAGE_START_TOKEN", "" + ), + "post_message": model_prompt_template.get( + "MODEL_SYSTEM_MESSAGE_END_TOKEN", "" + ), }, "user": { - "pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""), - "post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""), + "pre_message": model_prompt_template.get( + "MODEL_USER_MESSAGE_START_TOKEN", "" + ), + "post_message": model_prompt_template.get( + "MODEL_USER_MESSAGE_END_TOKEN", "" + ), }, "assistant": { - "pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""), - "post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""), - } + "pre_message": model_prompt_template.get( + "MODEL_ASSISTANT_MESSAGE_START_TOKEN", "" + ), + "post_message": model_prompt_template.get( + "MODEL_ASSISTANT_MESSAGE_END_TOKEN", "" + ), + }, }, - final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""), + final_prompt_value=model_prompt_template.get( + "MODEL_POST_PROMPT", "" + ), ) except: pass -def initialize(model, alias, api_base, api_version, debug, temperature, max_tokens, max_budget, telemetry, drop_params, - add_function_to_prompt, headers, save): +def initialize( + model, + alias, + api_base, + api_version, + debug, + temperature, + max_tokens, + max_budget, + telemetry, + drop_params, + add_function_to_prompt, + headers, + save, +): global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers user_model = model user_debug = debug @@ -263,7 +321,9 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env + os.environ[ + "AZURE_API_VERSION" + ] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param user_max_tokens = max_tokens dynamic_config[user_model]["max_tokens"] = max_tokens @@ -293,15 +353,16 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke def track_cost_callback( - kwargs, # kwargs to completion - completion_response, # response from completion - start_time, end_time # start/end time + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time ): - # track cost like this + # track cost like this # { # "Oct12": { # "gpt-4": 10, - # "claude-2": 12.01, + # "claude-2": 12.01, # }, # "Oct 15": { # "ollama/llama2": 0.0, @@ -309,28 +370,27 @@ def track_cost_callback( # } # } try: - # for streaming responses if "complete_streaming_response" in kwargs: 
- # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost + # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost completion_response = kwargs["complete_streaming_response"] input_text = kwargs["messages"] output_text = completion_response["choices"][0]["message"]["content"] response_cost = litellm.completion_cost( - model=kwargs["model"], - messages=input_text, - completion=output_text + model=kwargs["model"], messages=input_text, completion=output_text ) - model = kwargs['model'] + model = kwargs["model"] # for non streaming responses else: # we pass the completion_response obj if kwargs["stream"] != True: - response_cost = litellm.completion_cost(completion_response=completion_response) + response_cost = litellm.completion_cost( + completion_response=completion_response + ) model = completion_response["model"] - # read/write from json for storing daily model costs + # read/write from json for storing daily model costs cost_data = {} try: with open("costs.json") as f: @@ -338,6 +398,7 @@ def track_cost_callback( except FileNotFoundError: cost_data = {} import datetime + date = datetime.datetime.now().strftime("%b-%d-%Y") if date not in cost_data: cost_data[date] = {} @@ -348,7 +409,7 @@ def track_cost_callback( else: cost_data[date][kwargs["model"]] = { "cost": response_cost, - "num_requests": 1 + "num_requests": 1, } with open("costs.json", "w") as f: @@ -359,25 +420,21 @@ def track_cost_callback( def logger( - kwargs, # kwargs to completion - completion_response=None, # response from completion - start_time=None, - end_time=None # start/end time + kwargs, # kwargs to completion + completion_response=None, # response from completion + start_time=None, + end_time=None, # start/end time ): - log_event_type = kwargs['log_event_type'] + log_event_type = kwargs["log_event_type"] try: - if log_event_type == 'pre_api_call': + if log_event_type == "pre_api_call": inference_params = copy.deepcopy(kwargs) - timestamp = inference_params.pop('start_time') + timestamp = inference_params.pop("start_time") dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23] - log_data = { - dt_key: { - 'pre_api_call': inference_params - } - } + log_data = {dt_key: {"pre_api_call": inference_params}} try: - with open(log_file, 'r') as f: + with open(log_file, "r") as f: existing_data = json.load(f) except FileNotFoundError: existing_data = {} @@ -385,7 +442,7 @@ def logger( existing_data.update(log_data) def write_to_log(): - with open(log_file, 'w') as f: + with open(log_file, "w") as f: json.dump(existing_data, f, indent=2) thread = threading.Thread(target=write_to_log, daemon=True) @@ -423,14 +480,28 @@ def write_to_log(): def model_list(): if user_model != None: return dict( - data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}], + data=[ + { + "id": user_model, + "object": "model", + "created": 1677610602, + "owned_by": "openai", + } + ], object="list", ) else: all_models = litellm.utils.get_valid_models() return dict( - data=[{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in - all_models], + data=[ + { + "id": model, + "object": "model", + "created": 1677610602, + "owned_by": "openai", + } + for model in all_models + ], object="list", ) @@ -439,9 +510,16 @@ def model_list(): @router.post("/completions") async def completion(request: Request): data = await request.json() - return litellm_completion(data=data, type="completion", user_model=user_model, 
user_temperature=user_temperature, - user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, - user_debug=user_debug) + return litellm_completion( + data=data, + type="completion", + user_model=user_model, + user_temperature=user_temperature, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + user_headers=user_headers, + user_debug=user_debug, + ) @router.post("/v1/chat/completions") @@ -449,13 +527,20 @@ async def completion(request: Request): async def chat_completion(request: Request): data = await request.json() print_verbose(f"data passed in: {data}") - return litellm_completion(data, type="chat_completion", user_model=user_model, - user_temperature=user_temperature, user_max_tokens=user_max_tokens, - user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug) + return litellm_completion( + data, + type="chat_completion", + user_model=user_model, + user_temperature=user_temperature, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + user_headers=user_headers, + user_debug=user_debug, + ) def print_cost_logs(): - with open('costs.json', 'r') as f: + with open("costs.json", "r") as f: # print this in green print("\033[1;32m") print(f.read()) @@ -465,7 +550,7 @@ def print_cost_logs(): @router.get("/ollama_logs") async def retrieve_server_log(request: Request): - filepath = os.path.expanduser('~/.ollama/logs/server.log') + filepath = os.path.expanduser("~/.ollama/logs/server.log") return FileResponse(filepath)
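
Note: the load_config() path in this patch reads litellm.secrets.toml from LITELLM_CONFIG_PATH (or the appdirs user config dir) and looks for a [keys] table (exported as environment variables), a [general] table (add_function_to_prompt, drop_params, fallbacks, default_model), and per-model prompt_template tables. A minimal sketch of producing such a file with tomli_w, the same writer the proxy bootstraps; the key name, model name, and empty token values below are illustrative placeholders, not part of this patch:

    import tomli_w  # installed by the proxy's bootstrap pip call above

    # Hypothetical config mirroring the sections load_config() reads:
    #   [keys]    -> each entry is exported as an environment variable
    #   [general] -> proxy-wide litellm settings
    #   [model.<name>.prompt_template] -> custom prompt tokens for that model
    example_config = {
        "keys": {"OPENAI_API_KEY": "YOUR_OPENAI_KEY"},  # placeholder value
        "general": {
            "add_function_to_prompt": True,
            "drop_params": True,
            "default_model": "gpt-3.5-turbo",  # example only; used when no --model is set
        },
        "model": {
            "gpt-3.5-turbo": {
                "prompt_template": {
                    # fill these with your model's chat tokens if it needs a custom template
                    "MODEL_PRE_PROMPT": "",
                    "MODEL_SYSTEM_MESSAGE_START_TOKEN": "",
                    "MODEL_SYSTEM_MESSAGE_END_TOKEN": "",
                    "MODEL_POST_PROMPT": "",
                }
            }
        },
    }

    # Write it where the proxy looks by default, or point LITELLM_CONFIG_PATH at it.
    with open("litellm.secrets.toml", "wb") as f:
        tomli_w.dump(example_config, f)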
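
With the proxy running (the Dockerfile exposes port 8000), the refactored /v1/chat/completions handler still takes an OpenAI-style request body and hands it to litellm_completion unchanged. A quick smoke-test sketch using only the standard library; the base URL and model name are assumptions for a local, unauthenticated run:

    import json
    import urllib.request

    BASE_URL = "http://localhost:8000"  # assumed local proxy address

    def chat(prompt: str, model: str = "gpt-3.5-turbo") -> dict:
        """POST an OpenAI-style chat request to the proxy's /v1/chat/completions route."""
        payload = json.dumps(
            {
                "model": model,  # example name; any model the proxy is configured for
                "messages": [{"role": "user", "content": prompt}],
            }
        ).encode("utf-8")
        req = urllib.request.Request(
            f"{BASE_URL}/v1/chat/completions",
            data=payload,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            return json.load(resp)

    if __name__ == "__main__":
        print(chat("Hello from the refactored proxy!"))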
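
track_cost_callback above accumulates spend into costs.json, keyed first by a "%b-%d-%Y" date string and then by model, with "cost" and "num_requests" fields per model. As a complement to print_cost_logs(), which just dumps the raw file, here is a small sketch (assuming that file layout) that totals cost per day:

    import json

    def summarize_costs(path: str = "costs.json") -> None:
        """Print per-day and per-model totals from the file written by track_cost_callback."""
        try:
            with open(path) as f:
                cost_data = json.load(f)
        except FileNotFoundError:
            print("No costs.json yet - the proxy has not logged any completions.")
            return

        for date, models in cost_data.items():  # keys look like "Oct-18-2023"
            day_total = sum(entry["cost"] for entry in models.values())
            print(f"{date}: ${day_total:.4f}")
            for model, entry in models.items():
                print(f"  {model}: ${entry['cost']:.4f} over {entry['num_requests']} request(s)")

    if __name__ == "__main__":
        summarize_costs()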