From c2e2e927fb68a3bafc1606d3027caccadb882918 Mon Sep 17 00:00:00 2001
From: chabala98
Date: Fri, 1 Dec 2023 13:16:35 +0100
Subject: [PATCH] fix system prompts for replicate

---
 docs/my-website/docs/providers/replicate.md |  4 +-
 litellm/llms/replicate.py                   | 57 ++++++++++++---------
 2 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/docs/my-website/docs/providers/replicate.md b/docs/my-website/docs/providers/replicate.md
index d8ab035e1b01..3384ba35c283 100644
--- a/docs/my-website/docs/providers/replicate.md
+++ b/docs/my-website/docs/providers/replicate.md
@@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM
 
 Model Name                    | Function Call                                                    | Required OS Variables                |
 -----------------------------|----------------------------------------------------------------|--------------------------------------|
- replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
- a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
+ replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
+ a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
 replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
 daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
 custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index d639a8d1e43d..874b31bd6a69 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
         else:
             # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
             print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
+
 
 # Function to extract version ID from model string
 def model_to_version_id(model):
@@ -194,41 +195,47 @@ def completion(
 ):
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
-
     ## Load Config
     config = litellm.ReplicateConfig.get_config()
     for k, v in config.items():
         if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
+
+    system_prompt = None
+    if optional_params is not None and "supports_system_prompt" in optional_params:
+        supports_sys_prompt = optional_params.pop("supports_system_prompt")
+    else:
+        supports_sys_prompt = False
+
+    if supports_sys_prompt:
+        for i in range(len(messages)):
+            if messages[i]["role"] == "system":
+                first_sys_message = messages.pop(i)
+                system_prompt = first_sys_message["content"]
+                break
+
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
 
-    if "meta/llama-2-13b-chat" in model:
-        system_prompt = ""
-        prompt = ""
-        for message in messages:
-            if message["role"] == "system":
-                system_prompt = message["content"]
-            else:
-                prompt += message["content"]
+    # If system prompt is supported, and a system prompt is provided, use it
+    if system_prompt is not None:
         input_data = {
-            "system_prompt": system_prompt,
             "prompt": prompt,
-            **optional_params
+            "system_prompt": system_prompt
         }
+    # Otherwise, use the prompt as is
     else:
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", {}),
-                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                bos_token=model_prompt_details.get("bos_token", ""),
-                eos_token=model_prompt_details.get("eos_token", ""),
-                messages=messages,
-            )
-        else:
-            prompt = prompt_factory(model=model, messages=messages)
-        input_data = {
             "prompt": prompt,
             **optional_params
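
For reference, a minimal usage sketch of the behaviour this patch introduces, mirroring the updated docs table: `supports_system_prompt=True` is passed as an extra kwarg to `completion()` and, per the change in `litellm/llms/replicate.py`, is popped from `optional_params` so the first system message is sent to Replicate as a dedicated `system_prompt` input. The API key value and the message contents below are placeholders; the model version ID is the one from the docs table above.

```python
import os
import litellm

# Placeholder token; assumes a valid Replicate API key is available.
os.environ["REPLICATE_API_KEY"] = "r8_..."

# With supports_system_prompt=True, the first system message is removed from
# `messages` and forwarded to Replicate as the "system_prompt" input instead
# of being folded into the plain prompt string.
response = litellm.completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Explain what a system prompt is in one sentence."},
    ],
    supports_system_prompt=True,
)
print(response["choices"][0]["message"]["content"])
```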