feat: add VLLMProvider #1866

Merged
merged 5 commits on Oct 11, 2024
20 changes: 2 additions & 18 deletions letta/cli/cli.py
@@ -14,9 +14,7 @@
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger
from letta.metadata import MetadataStore
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import OptionState
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory, Memory
from letta.server.server import logger as server_logger

@@ -235,12 +233,7 @@ def run(
# choose from list of llm_configs
llm_configs = client.list_llm_configs()
llm_options = [llm_config.model for llm_config in llm_configs]

# TODO move into LLMConfig as a class method?
def prettify_llm_config(llm_config: LLMConfig) -> str:
return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""

llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs]

# select model
if len(llm_options) == 0:
@@ -255,17 +248,8 @@ def prettify_llm_config(llm_config: LLMConfig) -> str:
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]

# TODO move into EmbeddingConfig as a class method?
def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
return (
f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
if embedding_config.embedding_endpoint
else ""
)

embedding_choices = [
questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
for embedding_config in embedding_configs
questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs
]

# select model
4 changes: 4 additions & 0 deletions letta/llm_api/llm_api_tools.py
@@ -70,6 +70,10 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

except requests.exceptions.HTTPError as http_err:

if not hasattr(http_err, "response") or http_err.response is None:
raise

# Retry on specified errors
if http_err.response.status_code in error_codes:
# Increment retries
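Review note: a minimal sketch of the retry pattern this guard slots into, with hypothetical `error_codes` and backoff values. The key points are re-raising when the `HTTPError` carries no response at all, and checking `status_code` explicitly rather than the response's truthiness (which `requests` ties to `.ok`, so any 4xx/5xx response is falsy).

```python
import time

import requests


def retry_with_backoff(func, error_codes=(429,), max_retries=3, backoff=1.0):
    """Sketch of a retry wrapper: retry on selected HTTP status codes, re-raise everything else."""

    def wrapper(*args, **kwargs):
        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except requests.exceptions.HTTPError as http_err:
                # No response attached to the error, so there is nothing to inspect: re-raise
                if not hasattr(http_err, "response") or http_err.response is None:
                    raise
                # Retry only on the configured status codes, and only while retries remain
                if http_err.response.status_code in error_codes and attempt < max_retries:
                    time.sleep(backoff * (2**attempt))
                    continue
                raise

    return wrapper
```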
10 changes: 7 additions & 3 deletions letta/llm_api/openai.py
@@ -61,6 +61,7 @@ def openai_get_model_list(
headers["Authorization"] = f"Bearer {api_key}"

printd(f"Sending request to {url}")
response = None
try:
# TODO add query param "tool" to be true
response = requests.get(url, headers=headers, params=extra_params)
@@ -71,23 +72,26 @@
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
if response is not None:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
if response is not None:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
if response is not None:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
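Review note: for context on when these error paths fire, here is a hedged usage sketch of `openai_get_model_list` pointed at a vLLM server. The localhost URL is an assumption, and `max_model_len` is the field the new providers below read from vLLM's OpenAI-compatible `/models` response.

```python
from letta.llm_api.openai import openai_get_model_list

# Assumes a local vLLM OpenAI-compatible server reachable at this base URL
VLLM_API_BASE = "http://localhost:8000/v1"

response = openai_get_model_list(VLLM_API_BASE, api_key=None)
for model in response["data"]:
    # vLLM reports the model id and its maximum context length
    print(model["id"], model.get("max_model_len"))
```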
2 changes: 1 addition & 1 deletion letta/local_llm/vllm/api.py
@@ -3,7 +3,7 @@
from letta.local_llm.settings.settings import get_completions_settings
from letta.local_llm.utils import count_tokens, post_json_auth_request

WEBUI_API_SUFFIX = "/v1/completions"
WEBUI_API_SUFFIX = "/completions"


def get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user, grammar=None):
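Review note: dropping the `/v1` prefix from the suffix appears to assume the configured endpoint already ends in `/v1` (as `VLLM_API_BASE` does in the settings change below). A small sketch of the assumed URL construction, with a placeholder endpoint:

```python
# Hypothetical illustration; the real request is built inside get_vllm_completion
WEBUI_API_SUFFIX = "/completions"

endpoint = "http://localhost:8000/v1"  # assumed vLLM base URL, already ending in /v1
uri = endpoint.rstrip("/") + WEBUI_API_SUFFIX
assert uri == "http://localhost:8000/v1/completions"
```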
81 changes: 72 additions & 9 deletions letta/providers.py
@@ -14,14 +14,18 @@

class Provider(BaseModel):

def list_llm_models(self):
def list_llm_models(self) -> List[LLMConfig]:
return []

def list_embedding_models(self):
def list_embedding_models(self) -> List[EmbeddingConfig]:
return []

def get_model_context_window(self, model_name: str):
pass
def get_model_context_window(self, model_name: str) -> Optional[int]:
raise NotImplementedError

def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError


class LettaProvider(Provider):
@@ -162,7 +166,7 @@ def list_llm_models(self) -> List[LLMConfig]:
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:

import requests

@@ -310,7 +314,7 @@ def list_embedding_models(self):
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:
from letta.llm_api.google_ai import google_ai_get_model_context_window

return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
@@ -371,16 +375,75 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:
"""
This is hardcoded for now, since there are no API endpoints to retrieve metadata for a model.
"""
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)


class VLLMProvider(OpenAIProvider):
class VLLMChatCompletionsProvider(Provider):
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""

# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
pass
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")

def list_llm_models(self) -> List[LLMConfig]:
# vLLM's OpenAI-compatible server exposes a /models endpoint we can reuse here
from letta.llm_api.openai import openai_get_model_list

assert self.base_url, "base_url is required for vLLM provider"
response = openai_get_model_list(self.base_url, api_key=None)

configs = []
print(response)
for model in response["data"]:
configs.append(
LLMConfig(
model=model["id"],
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_model_len"],
)
)
return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# not supported with vLLM
return []


class VLLMCompletionsProvider(Provider):
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""

# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use with the vLLM /completions API.")

def list_llm_models(self) -> List[LLMConfig]:
# vLLM's OpenAI-compatible server exposes a /models endpoint we can reuse here
from letta.llm_api.openai import openai_get_model_list

response = openai_get_model_list(self.base_url, api_key=None)

configs = []
for model in response["data"]:
configs.append(
LLMConfig(
model=model["id"],
model_endpoint_type="vllm",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
)
)
return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# not supported with vLLM
return []


class CohereProvider(OpenAIProvider):
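Review note: a hedged sketch of how the two new providers could be instantiated and queried against a running vLLM server. The localhost URL and the `chatml` formatter name are assumptions, not part of this diff.

```python
from letta.providers import VLLMChatCompletionsProvider, VLLMCompletionsProvider

chat_provider = VLLMChatCompletionsProvider(base_url="http://localhost:8000/v1")
completions_provider = VLLMCompletionsProvider(
    base_url="http://localhost:8000/v1",
    default_prompt_formatter="chatml",  # hypothetical wrapper name; the server reads its default from settings
)

# Both calls hit the server's /models endpoint, so they require vLLM to be up
for config in chat_provider.list_llm_models() + completions_provider.list_llm_models():
    print(config.pretty_print())
```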
7 changes: 7 additions & 0 deletions letta/schemas/embedding_config.py
@@ -52,3 +52,10 @@ def default_config(cls, model_name: Optional[str] = None, provider: Optional[str
)
else:
raise ValueError(f"Model {model_name} not supported.")

def pretty_print(self) -> str:
return (
f"{self.embedding_model}"
+ (f" [type={self.embedding_endpoint_type}]" if self.embedding_endpoint_type else "")
+ (f" [ip={self.embedding_endpoint}]" if self.embedding_endpoint else "")
)
7 changes: 7 additions & 0 deletions letta/schemas/llm_config.py
@@ -68,3 +68,10 @@ def default_config(cls, model_name: str):
)
else:
raise ValueError(f"Model {model_name} not supported.")

def pretty_print(self) -> str:
return (
f"{self.model}"
+ (f" [type={self.model_endpoint_type}]" if self.model_endpoint_type else "")
+ (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
)
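Review note: for reference, a sketch of what `pretty_print` produces for a made-up config (the embedding variant above follows the same pattern):

```python
from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="gpt-4",
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    context_window=8192,
)
print(config.pretty_print())
# -> gpt-4 [type=openai] [ip=https://api.openai.com/v1]
```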
20 changes: 16 additions & 4 deletions letta/server/server.py
@@ -51,7 +51,8 @@
OllamaProvider,
OpenAIProvider,
Provider,
VLLMProvider,
VLLMChatCompletionsProvider,
VLLMCompletionsProvider,
)
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.api_key import APIKey, APIKeyCreate
@@ -268,19 +269,30 @@ def __init__(
if model_settings.anthropic_api_key:
self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
if model_settings.ollama_base_url:
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url))
if model_settings.vllm_base_url:
self._enabled_providers.append(VLLMProvider(base_url=model_settings.vllm_base_url))
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
if model_settings.gemini_api_key:
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
if model_settings.azure_api_key and model_settings.azure_base_url:
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
AzureProvider(
api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
)
)
if model_settings.vllm_api_base:
# vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append(
VLLMCompletionsProvider(
base_url=model_settings.vllm_api_base,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
)
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
# see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))

def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
10 changes: 8 additions & 2 deletions letta/settings.py
@@ -4,14 +4,20 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from letta.local_llm.constants import DEFAULT_WRAPPER_NAME


class ModelSettings(BaseSettings):

# env_prefix='my_prefix_'

# when we use /completions APIs (instead of /chat/completions), we need to specify a model wrapper
# the "model wrapper" is responsible for prompt formatting and function calling parsing
default_prompt_formatter: str = DEFAULT_WRAPPER_NAME

# openai
openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = "https://api.openai.com/v1"
openai_api_base: str = "https://api.openai.com/v1"

# groq
groq_api_key: Optional[str] = None
@@ -31,7 +37,7 @@ class ModelSettings(BaseSettings):
gemini_api_key: Optional[str] = None

# vLLM
vllm_base_url: Optional[str] = None
vllm_api_base: Optional[str] = None

# openllm
openllm_auth_type: Optional[str] = None
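Review note: because `ModelSettings` is a pydantic `BaseSettings` subclass with no `env_prefix`, the renamed field is populated from an environment variable of the same name (matched case-insensitively). A quick sketch:

```python
import os

from letta.settings import ModelSettings

# Assumed local vLLM endpoint; pydantic-settings maps the field name to this env var
os.environ["VLLM_API_BASE"] = "http://localhost:8000/v1"

settings = ModelSettings()
print(settings.vllm_api_base)             # http://localhost:8000/v1
print(settings.default_prompt_formatter)  # falls back to DEFAULT_WRAPPER_NAME
```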
8 changes: 8 additions & 0 deletions tests/test_providers.py
@@ -52,3 +52,11 @@ def test_googleai():
print(models)

provider.list_embedding_models()


# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
# print(models)
#
# provider.list_embedding_models()
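Review note: if the commented-out test were revived against the renamed providers, it might look like the sketch below. It assumes a reachable server behind `VLLM_API_BASE` and a `chatml` formatter, and skips otherwise:

```python
import os

import pytest

from letta.providers import VLLMCompletionsProvider


@pytest.mark.skipif(not os.getenv("VLLM_API_BASE"), reason="requires a running vLLM server")
def test_vllm_completions():
    provider = VLLMCompletionsProvider(
        base_url=os.getenv("VLLM_API_BASE"),
        default_prompt_formatter="chatml",  # hypothetical; the server default lives in ModelSettings
    )
    models = provider.list_llm_models()
    print(models)

    assert provider.list_embedding_models() == []
```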