Skip to content

Commit

Permalink
Merge pull request #970 from nbaldwin98/fixing-replicate-sys-prompt
Browse files Browse the repository at this point in the history
fix system prompts for replicate
  • Loading branch information
krrishdholakia authored Dec 5, 2023
2 parents 1247afb + c2e2e92 commit b90fcbd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
4 changes: 2 additions & 2 deletions docs/my-website/docs/providers/replicate.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM

Model Name | Function Call | Required OS Variables |
-----------------------------|----------------------------------------------------------------|--------------------------------------|
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
Expand Down
57 changes: 32 additions & 25 deletions litellm/llms/replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")


# Function to extract version ID from model string
def model_to_version_id(model):
Expand All @@ -194,41 +195,47 @@ def completion(
):
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)

## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v

system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt")
else:
supports_sys_prompt = False

if supports_sys_prompt:
for i in range(len(messages)):
if messages[i]["role"] == "system":
first_sys_message = messages.pop(i)
system_prompt = first_sys_message["content"]
break

if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)

if "meta/llama-2-13b-chat" in model:
system_prompt = ""
prompt = ""
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
else:
prompt += message["content"]
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
"system_prompt": system_prompt,
"prompt": prompt,
**optional_params
"system_prompt": system_prompt
}
# Otherwise, use the prompt as is
else:
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)

input_data = {
"prompt": prompt,
**optional_params
Expand Down

0 comments on commit b90fcbd

Please sign in to comment.