
feat: add Google AI Gemini Pro support #1209

Merged: 13 commits, Apr 11, 2024
14 changes: 7 additions & 7 deletions memgpt/agent.py
@@ -15,7 +15,7 @@
from memgpt.persistence_manager import LocalStateManager
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages, ArchivalMemory, RecallMemory
from memgpt.llm_api_tools import create, is_context_overflow_error
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
from memgpt.utils import (
get_utc_time,
create_random_username,
@@ -114,8 +114,8 @@
recall_memory: Optional[RecallMemory] = None,
include_char_count: bool = True,
):
full_system_message = "\n".join(

[CI annotation] GitHub Actions / Pyright types check (3.11), memgpt/agent.py line 117: No overloads for "join" match the provided arguments (reportCallIssue)
[

[CI annotation] GitHub Actions / Pyright types check (3.11), memgpt/agent.py line 118: Argument of type "list[str | Unknown | None]" cannot be assigned to parameter "iterable" of type "Iterable[str]" in function "join"; type "Unknown | None" cannot be assigned to type "str" ("None" is incompatible with "str") (reportArgumentType)
system,
"\n",
f"### Memory [last modified: {memory_edit_timestamp.strip()}]",
@@ -400,7 +400,7 @@

def _get_ai_reply(
self,
message_sequence: List[dict],
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
) -> chat_completion_response.ChatCompletionResponse:
@@ -694,12 +694,12 @@

self.interface.user_message(user_message.text, msg_obj=user_message)

input_message_sequence = self.messages + [user_message.to_openai_dict()]
input_message_sequence = self._messages + [user_message]
# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self.messages
input_message_sequence = self._messages

if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")

# Step 1: send the conversation and available functions to GPT
@@ -858,14 +858,14 @@
printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
cutoff += 1

message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]"
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")

# We can't do summarize logic properly if context_window is undefined
if self.agent_state.llm_config.context_window is None:
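Note: the two Pyright failures above come from `"\n".join(...)` receiving a list whose entries can be `None`. A minimal sketch of the kind of guard that clears this class of error (an illustration only, not necessarily the fix this PR ships):

```python
from typing import List, Optional

def join_non_null(parts: List[Optional[str]], sep: str = "\n") -> str:
    """Drop None entries so the argument satisfies Iterable[str]."""
    return sep.join([p for p in parts if p is not None])
```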
139 changes: 127 additions & 12 deletions memgpt/cli/cli_config.py
@@ -18,7 +18,9 @@
from memgpt.constants import MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES
from memgpt.data_types import User, LLMConfig, EmbeddingConfig
from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.llm_api.openai import openai_get_model_list
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.server.utils import shorten_key_middle
@@ -45,11 +47,16 @@ def get_azure_credentials():
return creds


def get_openai_credentials():
openai_key = os.getenv("OPENAI_API_KEY")
def get_openai_credentials() -> Optional[str]:
openai_key = os.getenv("OPENAI_API_KEY", None)
return openai_key


def get_google_ai_credentials() -> Optional[str]:
google_ai_key = os.getenv("GOOGLE_AI_API_KEY", None)
return google_ai_key


def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials):
# configure model endpoint
model_endpoint_type, model_endpoint = None, None
@@ -59,11 +66,12 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
"openai",
"azure",
"google_ai",
]: # local model
default_model_endpoint_type = "local"

provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
"Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
).ask()
if provider is None:
raise KeyboardInterrupt
@@ -131,6 +139,51 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]

elif provider == "google_ai":

# check for key
if credentials.google_ai_key is None:
# allow key to get pulled from env vars
google_ai_key = get_google_ai_credentials()
# if we still can't find it, ask for it as input
if google_ai_key is None:
while google_ai_key is None or len(google_ai_key) == 0:
# Ask for API key as input
google_ai_key = questionary.password(
"Enter your Google AI (Gemini) API key (see https://aistudio.google.com/app/apikey):"
).ask()
if google_ai_key is None:
raise KeyboardInterrupt
credentials.google_ai_key = google_ai_key
else:
# Give the user an opportunity to overwrite the key
google_ai_key = None
default_input = shorten_key_middle(credentials.google_ai_key)

google_ai_key = questionary.password(
"Enter your Google AI (Gemini) API key (see https://aistudio.google.com/app/apikey):",
default=default_input,
).ask()
if google_ai_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if google_ai_key != default_input:
credentials.google_ai_key = google_ai_key

default_input = os.getenv("GOOGLE_AI_SERVICE_ENDPOINT", None)
if default_input is None:
default_input = "generativelanguage"
google_ai_service_endpoint = questionary.text(
"Enter your Google AI (Gemini) service endpoint (see https://ai.google.dev/api/rest):",
default=default_input,
).ask()
credentials.google_ai_service_endpoint = google_ai_service_endpoint

# write out the credentials
credentials.save()

model_endpoint_type = "google_ai"

else: # local models
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
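Note: because the `google_ai` branch above checks the stored credential first and then falls back to `get_google_ai_credentials()`, the interactive prompts can be pre-empted by exporting the environment variables the code reads. A sketch (variable names are taken from this diff; values are placeholders):

```python
import os

os.environ["GOOGLE_AI_API_KEY"] = "<your-gemini-api-key>"        # placeholder value
os.environ["GOOGLE_AI_SERVICE_ENDPOINT"] = "generativelanguage"  # default used above
```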
@@ -223,6 +276,21 @@ def get_model_options(
else:
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]

elif model_endpoint_type == "google_ai":
if credentials.google_ai_key is None:
raise ValueError("Missing Google AI API key")
if credentials.google_ai_service_endpoint is None:
raise ValueError("Missing Google AI service endpoint")
model_options = google_ai_get_model_list(
service_endpoint=credentials.google_ai_service_endpoint, api_key=credentials.google_ai_key
)
model_options = [str(m["name"]) for m in model_options]
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]

# TODO remove manual filtering for gemini-pro
model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
# model_options = ["gemini-pro"]

else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
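Note: the body of `google_ai_get_model_list` is not shown in this diff. A plausible sketch against the public Generative Language REST API, assuming the URL shape implied by the `"generativelanguage"` service-endpoint default; it also shows why the `"models/"` prefix stripping above is needed:

```python
import requests
from typing import List

def list_gemini_models(service_endpoint: str, api_key: str) -> List[dict]:
    """Sketch: list available models via GET /v1beta/models."""
    url = f"https://{service_endpoint}.googleapis.com/v1beta/models"
    response = requests.get(url, params={"key": api_key})
    response.raise_for_status()
    # Each entry looks like {"name": "models/gemini-pro", ...}, which is why
    # the code above strips the "models/" prefix before showing options.
    return response.json().get("models", [])
```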
@@ -294,6 +362,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "google_ai":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e

model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

else: # local models

# ask about local auth
@@ -412,7 +500,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_

# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length

context_length_options = [
str(2**12), # 4096
str(2**13), # 8192
@@ -421,13 +509,40 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
str(2**18), # 262144
"custom", # enter yourself
]
context_window_input = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
if context_window_input is None:
raise KeyboardInterrupt

if model_endpoint_type == "google_ai":
try:
fetched_context_window = str(
google_ai_get_model_context_window(
service_endpoint=credentials.google_ai_service_endpoint, api_key=credentials.google_ai_key, model=model
)
)
print(f"Got context window {fetched_context_window} for model {model} (from Google API)")
context_length_options = [
fetched_context_window,
"custom",
]
except:
print(f"Failed to get model details for model '{model}' on Google AI API")

context_window_input = questionary.select(
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt

else:

# Ask the user to specify the context length
context_window_input = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
if context_window_input is None:
raise KeyboardInterrupt

# If custom, ask for input
if context_window_input == "custom":
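Note: `google_ai_get_model_context_window` is imported from memgpt/llm_api/google_ai.py, whose body is not part of this diff. A hedged sketch, assuming it reads the per-model token limit that the `models.get` REST endpoint exposes; whether the PR uses `inputTokenLimit` alone or combines it with `outputTokenLimit` is an assumption:

```python
import requests

def get_gemini_context_window(service_endpoint: str, api_key: str, model: str) -> int:
    """Sketch: fetch a model's input token limit via GET /v1beta/models/{model}."""
    url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}"
    response = requests.get(url, params={"key": api_key})
    response.raise_for_status()
    return int(response.json()["inputTokenLimit"])  # e.g. 30720 for gemini-pro
```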
2 changes: 2 additions & 0 deletions memgpt/constants.py
@@ -86,6 +86,8 @@
)
# The fraction of tokens we truncate down to
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75
# The acknowledgement message used in the summarize sequence
MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready."

# Even when summarizing, we want to keep a handful of recent messages
# These serve as in-context examples of how to use functions / what user messages look like
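Note: this diff adds the constant but not its call site. A sketch of the alternating-role summarize sequence such an ACK typically enables (the exact message layout is an assumption, not shown in this PR; some providers require strictly alternating user/assistant turns):

```python
MESSAGE_SUMMARY_REQUEST_ACK = (
    "Understood, I will respond with a summary of the message (and only the "
    "summary, nothing else) once I receive the conversation history. I'm ready."
)

def build_summarize_sequence(summary_request: str, history_text: str) -> list:
    # Hypothetical layout: the assistant ACK keeps roles strictly alternating.
    return [
        {"role": "user", "content": summary_request},
        {"role": "assistant", "content": MESSAGE_SUMMARY_REQUEST_ACK},
        {"role": "user", "content": history_text},
    ]
```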
13 changes: 12 additions & 1 deletion memgpt/credentials.py
@@ -28,6 +28,10 @@ class MemGPTCredentials:
openai_auth_type: str = "bearer_token"
openai_key: Optional[str] = None

# gemini config
google_ai_key: Optional[str] = None
google_ai_service_endpoint: Optional[str] = None

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
@@ -70,6 +74,9 @@ def load(cls) -> "MemGPTCredentials":
"azure_embedding_version": get_field(config, "azure", "embedding_version"),
"azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
# gemini
"google_ai_key": get_field(config, "google_ai", "key"),
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
@@ -102,7 +109,11 @@ def save(self):
set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)

# openai config
# gemini
set_field(config, "google_ai", "key", self.google_ai_key)
set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)

# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)

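Note: given the `get_field`/`set_field` calls above, the on-disk credentials file gains a `[google_ai]` section. A sketch of reading it back with `configparser` (the file path here is a placeholder; MemGPT keeps its credentials under `MEMGPT_DIR`):

```python
import configparser

# Assumed resulting section in the credentials file:
#
#   [google_ai]
#   key = <your-gemini-api-key>
#   service_endpoint = generativelanguage

config = configparser.ConfigParser()
config.read("credentials")  # placeholder path
google_ai_key = config.get("google_ai", "key", fallback=None)
google_ai_service_endpoint = config.get("google_ai", "service_endpoint", fallback=None)
```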