feat: add VLLMProvider (#1866)
Co-authored-by: cpacker <packercharles@gmail.com>
sarahwooders and cpacker authored Oct 11, 2024
1 parent 30ff274 commit 32fbd71
Showing 10 changed files with 132 additions and 37 deletions.
20 changes: 2 additions & 18 deletions letta/cli/cli.py
@@ -14,9 +14,7 @@
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger
from letta.metadata import MetadataStore
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import OptionState
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory, Memory
from letta.server.server import logger as server_logger

@@ -235,12 +233,7 @@ def run(
# choose from list of llm_configs
llm_configs = client.list_llm_configs()
llm_options = [llm_config.model for llm_config in llm_configs]

# TODO move into LLMConfig as a class method?
def prettify_llm_config(llm_config: LLMConfig) -> str:
return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""

llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs]

# select model
if len(llm_options) == 0:
@@ -255,17 +248,8 @@ def prettify_llm_config(llm_config: LLMConfig) -> str:
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]

# TODO move into EmbeddingConfig as a class method?
def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
return (
f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
if embedding_config.embedding_endpoint
else ""
)

embedding_choices = [
questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
for embedding_config in embedding_configs
questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs
]

# select model
4 changes: 4 additions & 0 deletions letta/llm_api/llm_api_tools.py
@@ -70,6 +70,10 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

except requests.exceptions.HTTPError as http_err:

if not hasattr(http_err, "response") or not http_err.response:
raise

# Retry on specified errors
if http_err.response.status_code in error_codes:
# Increment retries
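As an aside, a minimal sketch (not part of the commit) of the case the new guard covers — an HTTPError raised without an attached response, where dereferencing http_err.response.status_code would otherwise fail:

import requests

err = requests.exceptions.HTTPError("request failed")  # no response attached
print(err.response)  # -> None, so the retry wrapper now re-raises immediately
                     #    instead of inspecting err.response.status_code
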
10 changes: 7 additions & 3 deletions letta/llm_api/openai.py
@@ -61,6 +61,7 @@ def openai_get_model_list(
headers["Authorization"] = f"Bearer {api_key}"

printd(f"Sending request to {url}")
response = None
try:
# TODO add query param "tool" to be true
response = requests.get(url, headers=headers, params=extra_params)
@@ -71,23 +72,26 @@
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
if response:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
if response:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
if response:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
2 changes: 1 addition & 1 deletion letta/local_llm/vllm/api.py
@@ -3,7 +3,7 @@
from letta.local_llm.settings.settings import get_completions_settings
from letta.local_llm.utils import count_tokens, post_json_auth_request

WEBUI_API_SUFFIX = "/v1/completions"
WEBUI_API_SUFFIX = "/completions"


def get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user, grammar=None):
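For context, a small sketch (not part of the commit) of how the shortened suffix is expected to combine with a base URL that already includes /v1; the localhost address is an assumption:

endpoint = "http://localhost:8000/v1"   # hypothetical vLLM base URL (already ends in /v1)
WEBUI_API_SUFFIX = "/completions"
uri = endpoint + WEBUI_API_SUFFIX
print(uri)                              # -> http://localhost:8000/v1/completions
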
81 changes: 72 additions & 9 deletions letta/providers.py
@@ -14,14 +14,18 @@

class Provider(BaseModel):

def list_llm_models(self):
def list_llm_models(self) -> List[LLMConfig]:
return []

def list_embedding_models(self):
def list_embedding_models(self) -> List[EmbeddingConfig]:
return []

def get_model_context_window(self, model_name: str):
pass
def get_model_context_window(self, model_name: str) -> Optional[int]:
raise NotImplementedError

def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError


class LettaProvider(Provider):
@@ -162,7 +166,7 @@ def list_llm_models(self) -> List[LLMConfig]:
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:

import requests

@@ -310,7 +314,7 @@ def list_embedding_models(self):
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:
from letta.llm_api.google_ai import google_ai_get_model_context_window

return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
@@ -371,16 +375,75 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
)
return configs

def get_model_context_window(self, model_name: str):
def get_model_context_window(self, model_name: str) -> Optional[int]:
"""
This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
"""
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)


class VLLMProvider(OpenAIProvider):
class VLLMChatCompletionsProvider(Provider):
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""

# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
pass
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")

def list_llm_models(self) -> List[LLMConfig]:
# not supported with vLLM
from letta.llm_api.openai import openai_get_model_list

assert self.base_url, "base_url is required for vLLM provider"
response = openai_get_model_list(self.base_url, api_key=None)

configs = []
print(response)
for model in response["data"]:
configs.append(
LLMConfig(
model=model["id"],
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_model_len"],
)
)
return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# not supported with vLLM
return []


class VLLMCompletionsProvider(Provider):
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""

# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

def list_llm_models(self) -> List[LLMConfig]:
# not supported with vLLM
from letta.llm_api.openai import openai_get_model_list

response = openai_get_model_list(self.base_url, api_key=None)

configs = []
for model in response["data"]:
configs.append(
LLMConfig(
model=model["id"],
model_endpoint_type="vllm",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
)
)
return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# not supported with vLLM
return []


class CohereProvider(OpenAIProvider):
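For orientation, a minimal usage sketch of the two new provider classes (not part of the commit; the base URL and the "chatml" formatter name are assumptions, and a running vLLM OpenAI-compatible server is assumed):

from letta.providers import VLLMChatCompletionsProvider, VLLMCompletionsProvider

# /completions backend: requires a prompt formatter ("model wrapper")
completions_provider = VLLMCompletionsProvider(
    base_url="http://localhost:8000/v1",   # hypothetical vLLM endpoint
    default_prompt_formatter="chatml",     # hypothetical wrapper name
)
for llm_config in completions_provider.list_llm_models():
    print(llm_config.pretty_print())       # entries use model_endpoint_type="vllm"

# /chat/completions backend: treated as an OpenAI-compatible endpoint
chat_provider = VLLMChatCompletionsProvider(base_url="http://localhost:8000/v1")
for llm_config in chat_provider.list_llm_models():
    print(llm_config.pretty_print())       # entries use model_endpoint_type="openai"
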
7 changes: 7 additions & 0 deletions letta/schemas/embedding_config.py
@@ -52,3 +52,10 @@ def default_config(cls, model_name: Optional[str] = None, provider: Optional[str
)
else:
raise ValueError(f"Model {model_name} not supported.")

def pretty_print(self) -> str:
return (
f"{self.embedding_model}"
+ (f" [type={self.embedding_endpoint_type}]" if self.embedding_endpoint_type else "")
+ (f" [ip={self.embedding_endpoint}]" if self.embedding_endpoint else "")
)
7 changes: 7 additions & 0 deletions letta/schemas/llm_config.py
@@ -68,3 +68,10 @@ def default_config(cls, model_name: str):
)
else:
raise ValueError(f"Model {model_name} not supported.")

def pretty_print(self) -> str:
return (
f"{self.model}"
+ (f" [type={self.model_endpoint_type}]" if self.model_endpoint_type else "")
+ (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
)
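
For illustration, a minimal sketch (not part of the commit) of what the new pretty_print helper returns; the field values below are assumptions. EmbeddingConfig.pretty_print follows the same pattern using the embedding_* fields.

from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="gpt-4",                               # assumed example values
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    context_window=8192,
)
print(config.pretty_print())
# -> gpt-4 [type=openai] [ip=https://api.openai.com/v1]
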
20 changes: 16 additions & 4 deletions letta/server/server.py
@@ -51,7 +51,8 @@
OllamaProvider,
OpenAIProvider,
Provider,
VLLMProvider,
VLLMChatCompletionsProvider,
VLLMCompletionsProvider,
)
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.api_key import APIKey, APIKeyCreate
@@ -244,19 +245,30 @@ def __init__(
if model_settings.anthropic_api_key:
self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
if model_settings.ollama_base_url:
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url))
if model_settings.vllm_base_url:
self._enabled_providers.append(VLLMProvider(base_url=model_settings.vllm_base_url))
self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
if model_settings.gemini_api_key:
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
if model_settings.azure_api_key and model_settings.azure_base_url:
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
AzureProvider(
api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
)
)
if model_settings.vllm_api_base:
# vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append(
VLLMCompletionsProvider(
base_url=model_settings.vllm_api_base,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
)
# NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
# see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))

def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
10 changes: 8 additions & 2 deletions letta/settings.py
@@ -4,14 +4,20 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from letta.local_llm.constants import DEFAULT_WRAPPER_NAME


class ModelSettings(BaseSettings):

# env_prefix='my_prefix_'

# when we use /completions APIs (instead of /chat/completions), we need to specify a model wrapper
# the "model wrapper" is responsible for prompt formatting and function calling parsing
default_prompt_formatter: str = DEFAULT_WRAPPER_NAME

# openai
openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = "https://api.openai.com/v1"
openai_api_base: str = "https://api.openai.com/v1"

# groq
groq_api_key: Optional[str] = None
@@ -34,7 +40,7 @@ class ModelSettings(BaseSettings):
gemini_api_key: Optional[str] = None

# vLLM
vllm_base_url: Optional[str] = None
vllm_api_base: Optional[str] = None

# openllm
openllm_auth_type: Optional[str] = None
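For reference, a minimal sketch (not part of the commit) of how the renamed setting is picked up from the environment by pydantic-settings; the URL below is an assumption:

import os

os.environ["VLLM_API_BASE"] = "http://localhost:8000/v1"  # maps to the vllm_api_base field

from letta.settings import ModelSettings

settings = ModelSettings()
print(settings.vllm_api_base)             # -> http://localhost:8000/v1
print(settings.default_prompt_formatter)  # -> DEFAULT_WRAPPER_NAME from letta.local_llm.constants
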
8 changes: 8 additions & 0 deletions tests/test_providers.py
@@ -52,3 +52,11 @@ def test_googleai():
print(models)

provider.list_embedding_models()


# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
# print(models)
#
# provider.list_embedding_models()
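
A hypothetical enabled version of this test against the split provider classes might look like the sketch below (not part of the commit; it assumes a reachable vLLM server at $VLLM_API_BASE and uses an assumed formatter name):

import os
from letta.providers import VLLMCompletionsProvider

def test_vllm():
    provider = VLLMCompletionsProvider(
        base_url=os.getenv("VLLM_API_BASE"),
        default_prompt_formatter="chatml",  # assumed wrapper name
    )
    models = provider.list_llm_models()
    print(models)

    # embeddings are not served by vLLM, so this should be empty
    assert provider.list_embedding_models() == []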
