feat: pull model list for openai-compatible endpoints (#630)
* allow entering custom model name when using openai/azure

* pull models from endpoint

* added/tested vllm and azure

* no print

* make red

* make the endpoint question give you an opportunity to enter your openai api key again in case you made a mistake / want to swap it out

* add cascading workflow for openai+azure model listings

* patched bug w/ azure listing
cpacker authored Dec 22, 2023
1 parent 09c7fa7 commit b97064e
Showing 2 changed files with 198 additions and 20 deletions.
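At a high level, the new configure_model flow cascades through three menus: a short hardcoded list of known-good chat models, then (on request) the full model list fetched live from the OpenAI/Azure endpoint, and finally free-form entry. The sketch below condenses that cascade into one standalone function; select_model is an illustrative name rather than a function in this commit, and the prompt strings are simplified from the diff that follows.

import questionary

def select_model(hardcoded, fetched, current=None):
    """Cascade: hardcoded list -> full fetched list -> manual entry."""
    see_all, manual = "[see all options]", "[enter model name manually]"
    default = current if current in hardcoded else hardcoded[0]
    model = questionary.select(
        "Select default model:", choices=hardcoded + [see_all, manual], default=default
    ).ask()
    # Only fall through to the fetched list if the endpoint call succeeded
    if model == see_all and fetched:
        model = questionary.select(
            "Select default model:", choices=fetched + [manual], default=fetched[0]
        ).ask()
    # Manual entry as the last resort; loop until non-empty
    if model == manual:
        model = ""
        while len(model) == 0:
            model = questionary.text("Enter custom model name:").ask()
    return model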
129 changes: 109 additions & 20 deletions memgpt/cli/cli_config.py
@@ -14,6 +14,8 @@
from memgpt.constants import LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.server.utils import shorten_key_middle

app = typer.Typer()

@@ -63,6 +65,18 @@ def configure_llm_endpoint(config: MemGPTConfig):
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
config.openai_key = openai_api_key
config.save()
else:
# Give the user an opportunity to overwrite the key
openai_api_key = None
default_input = shorten_key_middle(config.openai_key) if config.openai_key.startswith("sk-") else config.openai_key
openai_api_key = questionary.text(
"Enter your OpenAI API key (hit enter to use existing key):",
default=default_input,
).ask()
# If the user modified it, use the new one
if openai_api_key != default_input:
config.openai_key = openai_api_key
config.save()

@@ -78,6 +92,11 @@ def configure_llm_endpoint(config: MemGPTConfig):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set them, then run `memgpt configure` again."
)
else:
config.azure_key = azure_creds["azure_key"]
config.azure_endpoint = azure_creds["azure_endpoint"]
config.azure_version = azure_creds["azure_version"]
config.save()

model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]
@@ -119,16 +138,56 @@ def configure_llm_endpoint(config: MemGPTConfig):
return model_endpoint_type, model_endpoint


-def configure_model(config: MemGPTConfig, model_endpoint_type: str):
def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoint: str):
# set: model, model_wrapper
model, model_wrapper = None, None
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
-model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
-# TODO: select
-valid_model = config.model in model_options
# Get the model list from the openai / azure endpoint
hardcoded_model_options = ["gpt-4", "gpt-4-32k", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
fetched_model_options = None
try:
if model_endpoint_type == "openai":
fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=config.openai_key)
elif model_endpoint_type == "azure":
fetched_model_options = azure_openai_get_model_list(
url=model_endpoint, api_key=config.azure_key, api_version=config.azure_version
)
fetched_model_options = [obj["id"] for obj in fetched_model_options["data"] if obj["id"].startswith("gpt-")]
except:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)

# First ask if the user wants to see the full model list (some may be incompatible)
see_all_option_str = "[see all options]"
other_option_str = "[enter model name manually]"

# Check if the model we have set already is even in the list (informs our default)
valid_model = config.model in hardcoded_model_options
model = questionary.select(
-"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
"Select default model (recommended: gpt-4):",
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
default=config.model if valid_model else hardcoded_model_options[0],
).ask()

# If the user asked for the full list, show it
if model == see_all_option_str:
typer.secho(f"Warning: not all models shown are guaranteed to work with MemGPT", fg=typer.colors.RED)
model = questionary.select(
"Select default model (recommended: gpt-4):",
choices=fetched_model_options + [other_option_str],
default=config.model if valid_model else fetched_model_options[0],
).ask()

# Finally if the user asked to manually input, allow it
if model == other_option_str:
model = ""
while len(model) == 0:
model = questionary.text(
"Enter custom model name:",
).ask()

else: # local models
# ollama also needs model type
if model_endpoint_type == "ollama":
@@ -139,24 +198,51 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str):
).ask()
model = None if len(model) == 0 else model

default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""

# vllm needs huggingface model tag
if model_endpoint_type == "vllm":
-default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
-model = questionary.text(
-"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
-default=default_model,
-).ask()
-model = None if len(model) == 0 else model
-model_wrapper = None # no model wrapper for vLLM
try:
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
# + probably has custom model names
model_options = openai_get_model_list(url=smart_urljoin(model_endpoint, "v1"), api_key=None)
model_options = [obj["id"] for obj in model_options["data"]]
except:
print(f"Failed to get model list from {model_endpoint}, using defaults")
model_options = None

# If we got model options from vLLM endpoint, allow selection + custom input
if model_options is not None:
other_option_str = "other (enter name)"
valid_model = config.model in model_options
model_options.append(other_option_str)
model = questionary.select(
"Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
).ask()

# If we got custom input, ask for raw input
if model == other_option_str:
model = questionary.text(
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
# TODO allow empty string for input?
model = None if len(model) == 0 else model

else:
model = questionary.text(
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
model = None if len(model) == 0 else model

# model wrapper
-if model_endpoint_type != "vllm":
-available_model_wrappers = builtins.list(get_available_wrappers().keys())
-model_wrapper = questionary.select(
-f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
-choices=available_model_wrappers,
-default=DEFAULT_WRAPPER_NAME,
-).ask()
available_model_wrappers = builtins.list(get_available_wrappers().keys())
model_wrapper = questionary.select(
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()

# set: context_window
if str(model) not in LLM_MAX_TOKENS:
@@ -228,6 +314,7 @@ def configure_embedding_endpoint(config: MemGPTConfig):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set them, then run `memgpt configure` again."
)
# TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them

embedding_endpoint_type = "azure"
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
@@ -345,7 +432,9 @@ def configure():
config = MemGPTConfig.load()
try:
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
-model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
model, model_wrapper, context_window = configure_model(
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
89 changes: 89 additions & 0 deletions memgpt/openai_tools.py
@@ -2,6 +2,7 @@
import time
import requests
import time
from typing import Callable, TypeVar, Union
import urllib

from box import Box
@@ -75,6 +76,94 @@ def clean_azure_endpoint(raw_endpoint_name):
return endpoint_address


def openai_get_model_list(url: str, api_key: Union[str, None]) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from memgpt.utils import printd

url = smart_urljoin(url, "models")

headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"

printd(f"Sending request to {url}")
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response = {response}")
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
raise e
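
For reference, a hedged usage sketch of openai_get_model_list against the public OpenAI endpoint. The base URL and the "gpt-" prefix filter mirror what cli_config.py does above; the response follows the documented /v1/models shape, i.e. {"object": "list", "data": [{"id": "gpt-4", ...}, ...]}.

import os
from memgpt.openai_tools import openai_get_model_list

# Assumes OPENAI_API_KEY is set in your environment.
response = openai_get_model_list(url="https://api.openai.com/v1", api_key=os.environ["OPENAI_API_KEY"])
# Keep only chat-capable entries, as the configure flow does:
gpt_models = [obj["id"] for obj in response["data"] if obj["id"].startswith("gpt-")]
print(gpt_models)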


def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
from memgpt.utils import printd

# https://xxx.openai.azure.com/openai/models?api-version=xxx
url = smart_urljoin(url, "openai")
url = smart_urljoin(url, f"models?api-version={api_version}")

headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["api-key"] = f"{api_key}"

printd(f"Sending request to {url}")
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response = {response}")
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
raise e
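
The Azure variant above differs from the OpenAI helper only in authentication (an api-key header instead of a Bearer token) and in the URL, which appends /openai/models?api-version=... to the resource endpoint. A sketch with placeholder values; substitute your own resource name, key, and API version:

from memgpt.openai_tools import azure_openai_get_model_list

response = azure_openai_get_model_list(
    url="https://my-resource.openai.azure.com",  # joined to /openai/models?api-version=2023-05-15
    api_key="<azure-api-key>",
    api_version="2023-05-15",
)
model_ids = [obj["id"] for obj in response["data"]]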


def openai_chat_completions_request(url, api_key, data):
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
from memgpt.utils import printd
