Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Cohere API support (Command-R+) #1246

Merged
merged 4 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.llm_api.anthropic import anthropic_get_model_list, antropic_get_model_context_window
from memgpt.llm_api.cohere import cohere_get_model_list, cohere_get_model_context_window, COHERE_VALID_MODEL_LIST
from memgpt.llm_api.llm_api_tools import LLM_API_PROVIDER_OPTIONS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
Expand Down Expand Up @@ -226,6 +227,44 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
raise KeyboardInterrupt
provider = "anthropic"

elif provider == "cohere":
# check for key
if credentials.cohere_key is None:
# allow key to get pulled from env vars
cohere_api_key = os.getenv("COHERE_API_KEY", None)
# if we still can't find it, ask for it as input
if cohere_api_key is None:
while cohere_api_key is None or len(cohere_api_key) == 0:
# Ask for API key as input
cohere_api_key = questionary.password("Enter your Cohere API key (see https://dashboard.cohere.com/api-keys):").ask()
if cohere_api_key is None:
raise KeyboardInterrupt
credentials.cohere_key = cohere_api_key
credentials.save()
else:
# Give the user an opportunity to overwrite the key
cohere_api_key = None
default_input = (
shorten_key_middle(credentials.cohere_key) if credentials.cohere_key.startswith("sk-") else credentials.cohere_key
)
cohere_api_key = questionary.password(
"Enter your Cohere API key (see https://dashboard.cohere.com/api-keys):",
default=default_input,
).ask()
if cohere_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if cohere_api_key != default_input:
credentials.cohere_key = cohere_api_key
credentials.save()

model_endpoint_type = "cohere"
model_endpoint = "https://api.cohere.ai/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = "cohere"

else: # local models
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
Expand Down Expand Up @@ -339,6 +378,12 @@ def get_model_options(
fetched_model_options = anthropic_get_model_list(url=model_endpoint, api_key=credentials.anthropic_key)
model_options = [obj["name"] for obj in fetched_model_options]

elif model_endpoint_type == "cohere":
if credentials.cohere_key is None:
raise ValueError("Missing Cohere API key")
fetched_model_options = cohere_get_model_list(url=model_endpoint, api_key=credentials.cohere_key)
model_options = [obj for obj in fetched_model_options]

else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
Expand Down Expand Up @@ -450,6 +495,58 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "cohere":

fetched_model_options = []
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e

fetched_model_options = [m["name"] for m in fetched_model_options]
hardcoded_model_options = [m for m in fetched_model_options if m in COHERE_VALID_MODEL_LIST]

# First ask if the user wants to see the full model list (some may be incompatible)
see_all_option_str = "[see all options]"
other_option_str = "[enter model name manually]"

# Check if the model we have set already is even in the list (informs our default)
valid_model = config.default_llm_config.model in hardcoded_model_options
model = questionary.select(
"Select default model (recommended: command-r-plus):",
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
default=config.default_llm_config.model if valid_model else hardcoded_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

# If the user asked for the full list, show it
if model == see_all_option_str:
typer.secho(f"Warning: not all models shown are guaranteed to work with MemGPT", fg=typer.colors.RED)
model = questionary.select(
"Select default model (recommended: command-r-plus):",
choices=fetched_model_options + [other_option_str],
default=config.default_llm_config.model if valid_model else fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

# Finally if the user asked to manually input, allow it
if model == other_option_str:
model = ""
while len(model) == 0:
model = questionary.text(
"Enter custom model name:",
).ask()
if model is None:
raise KeyboardInterrupt

else: # local models

# ask about local auth
Expand Down Expand Up @@ -622,6 +719,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if context_window_input is None:
raise KeyboardInterrupt

elif model_endpoint_type == "cohere":
try:
fetched_context_window = str(
cohere_get_model_context_window(url=model_endpoint, api_key=credentials.cohere_key, model=model)
)
print(f"Got context window {fetched_context_window} for model {model}")
context_length_options = [
fetched_context_window,
"custom",
]
except Exception as e:
print(f"Failed to get model details for model '{model}' ({str(e)})")

context_window_input = questionary.select(
"Select your model's context window (see https://docs.cohere.com/docs/command-r):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt

else:

# Ask the user to specify the context length
Expand Down
8 changes: 8 additions & 0 deletions memgpt/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class MemGPTCredentials:
# anthropic config
anthropic_key: Optional[str] = None

# cohere config
cohere_key: Optional[str] = None

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
Expand Down Expand Up @@ -82,6 +85,8 @@ def load(cls) -> "MemGPTCredentials":
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
# anthropic
"anthropic_key": get_field(config, "anthropic", "key"),
# cohere
"cohere_key": get_field(config, "cohere", "key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
Expand Down Expand Up @@ -121,6 +126,9 @@ def save(self):
# anthropic
set_field(config, "anthropic", "key", self.anthropic_key)

# cohere
set_field(config, "cohere", "key", self.cohere_key)

# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)
Expand Down
102 changes: 102 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,108 @@ def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:

return google_ai_message

def to_cohere_dict(
    self,
    function_call_role: Optional[str] = "SYSTEM",
    function_call_prefix: Optional[str] = "[CHATBOT called function]",
    function_response_role: Optional[str] = "SYSTEM",
    function_response_prefix: Optional[str] = "[CHATBOT function returned]",
    inner_thoughts_as_kwarg: Optional[bool] = False,
) -> List[dict]:
    """Convert this message into a list of Cohere chat_history entries.

    Cohere chat_history dicts carry only 'role' and 'message' fields, so a
    single MemGPT message may expand into several entries, e.g.:
        assistant [cot]: "I'll send a message"
        assistant [func]: send_message("hi")
        tool: {'status': 'OK'}
    becomes:
        CHATBOT.text: "I'll send a message"
        SYSTEM.text: [CHATBOT called function] send_message("hi")
        SYSTEM.text: [CHATBOT function returned] {'status': 'OK'}

    Args:
        function_call_role: chat_history role used for function-call entries.
        function_call_prefix: text prepended to each rendered function call.
        function_response_role: chat_history role used for function results.
        function_response_prefix: text prepended to each function result.
        inner_thoughts_as_kwarg: not supported yet; raises if combined with
            an assistant message that has both text and tool calls.

    Returns:
        A list of {'role': ..., 'message': ...} dicts.

    Raises:
        UserWarning: for 'system' messages (those belong in 'preamble').
        NotImplementedError: if inner_thoughts_as_kwarg is requested.
        ValueError: for an empty assistant message or an unknown role.

    TODO: revisit this packing style once Cohere's guidance on embedded
    function calls in multi-turn conversation becomes clearer.
    """

    role = self.role

    if role == "system":
        # Per Cohere's API, system-level instructions go in the 'preamble'
        # parameter, not in chat_history — refuse the conversion outright.
        raise UserWarning(f"role 'system' messages should go in 'preamble' field for Cohere API")

    if role == "user":
        assert self.text is not None and self.role is not None, vars(self)
        return [{"role": "USER", "message": self.text}]

    if role == "assistant":
        # An assistant turn may be inner thoughts, function call(s), or both.
        assert self.tool_calls is not None or self.text is not None

        def pack_call(payload: str) -> dict:
            # Function calls are rendered as prefixed text under the
            # configured (SYSTEM by default) role.
            return {"role": function_call_role, "message": f"{function_call_prefix} {payload}"}

        has_text = bool(self.text)
        has_calls = bool(self.tool_calls)

        if has_text and has_calls:
            if inner_thoughts_as_kwarg:
                raise NotImplementedError
            entries = [{"role": "CHATBOT", "message": self.text}]
            for call in self.tool_calls:
                # TODO better way to pack?
                # Render as name(k=v,...) pseudo-call text.
                call_args = json.loads(call.function["arguments"])
                rendered_args = ",".join(f"{key}={value}" for key, value in call_args.items())
                entries.append(pack_call(f"{call.function['name']}({rendered_args})"))
            return entries

        if has_calls:
            # TODO better way to pack?
            # NOTE(review): this branch packs the raw to_dict() JSON, unlike
            # the name(args) rendering above — kept as-is on purpose.
            return [pack_call(json.dumps(call.to_dict())) for call in self.tool_calls]

        if has_text:
            return [{"role": "CHATBOT", "message": self.text}]

        raise ValueError("Message does not have content nor tool_calls")

    if role == "tool":
        assert self.role is not None and self.tool_call_id is not None, vars(self)
        return [
            {
                "role": function_response_role,
                "message": f"{function_response_prefix} {self.text}",
            }
        ]

    raise ValueError(role)


class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
Expand Down
Loading
Loading