diff --git a/letta/providers.py b/letta/providers.py
index 6fa98327f3..10821a96d6 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Optional
 
 from pydantic import BaseModel, Field, model_validator
@@ -10,6 +11,7 @@
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
+from letta.utils import smart_urljoin
 
 
 class Provider(BaseModel):
@@ -512,4 +514,78 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
 
 
 class CohereProvider(OpenAIProvider):
+    # TODO(matt)
     pass
+
+
+class LMStudioCompletionsProvider(Provider):
+    """LMStudio server via /completions API (not /chat/completions)
+
+    See: https://lmstudio.ai/docs/basics/server#openai-like-api-endpoints
+    """
+
+    name: str = "lmstudio"
+    base_url: str = Field(..., description="Base URL for the LMStudio API.")
+    default_prompt_formatter: str = Field(
+        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
+    )
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # LMStudio serves model metadata under its /api/v0 REST API
+        from letta.llm_api.openai import openai_get_model_list
+
+        model_path = smart_urljoin(self.base_url, "api/v0/")
+        response = openai_get_model_list(model_path, api_key=None)
+
+        if "data" not in response:
+            warnings.warn(f"LMStudio returned an unexpected response: {response}")
+            return []
+
+        configs = []
+        for model in response["data"]:
+            # LMStudio's backend has a type field which can be "llm" or "embeddings"
+            if "type" in model and model["type"] == "llm":
+                if "max_context_length" not in model:
+                    warnings.warn(f"LMStudio model is missing max_context_length field: {model}")
+                if "id" not in model:
+                    warnings.warn(f"LMStudio model is missing id field: {model}")
+                configs.append(
+                    LLMConfig(
+                        model=model["id"],
+                        model_endpoint_type="lmstudio",
+                        model_endpoint=self.base_url,
+                        model_wrapper=self.default_prompt_formatter,
+                        context_window=model["max_context_length"],
+                    )
+                )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from letta.llm_api.openai import openai_get_model_list
+
+        model_path = smart_urljoin(self.base_url, "api/v0/")
+        response = openai_get_model_list(model_path, api_key=None)
+
+        if "data" not in response:
+            warnings.warn(f"LMStudio returned an unexpected response: {response}")
+            return []
+
+        configs = []
+        for model in response["data"]:
+            # LMStudio's backend has a type field which can be "llm" or "embeddings"
+            if "type" in model and model["type"] == "embeddings":
+                if "max_context_length" not in model:
+                    warnings.warn(f"LMStudio model is missing max_context_length field: {model}")
+                if "id" not in model:
+                    warnings.warn(f"LMStudio model is missing id field: {model}")
+                configs.append(
+                    EmbeddingConfig(
+                        embedding_model=model["id"],
+                        embedding_endpoint_type="openai",
+                        embedding_endpoint=self.base_url,
+                        embedding_dim=model["max_context_length"],
+                        embedding_chunk_size=300,
+                    )
+                )
+
+        return configs
diff --git a/letta/server/server.py b/letta/server/server.py
index 83ba2a96a0..f0f30ef777 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -45,6 +45,7 @@
     GoogleAIProvider,
     GroqProvider,
     LettaProvider,
+    LMStudioCompletionsProvider,
     OllamaProvider,
     OpenAIProvider,
     Provider,
@@ -301,6 +302,13 @@ def __init__(
         )
         if model_settings.groq_api_key:
             self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
+        if model_settings.lmstudio_base_url:
+            self._enabled_providers.append(
+                LMStudioCompletionsProvider(
+                    base_url=model_settings.lmstudio_base_url,
+                    default_prompt_formatter=model_settings.default_prompt_formatter,
+                )
+            )
         if model_settings.vllm_api_base:
             # vLLM exposes both a /chat/completions and a /completions endpoint
             self._enabled_providers.append(
diff --git a/letta/settings.py b/letta/settings.py
index c5e7ee3bc6..c78a0d0ce1 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -32,6 +32,9 @@ class ModelSettings(BaseSettings):
     # ollama
     ollama_base_url: Optional[str] = None
 
+    # lmstudio base url
+    lmstudio_base_url: Optional[str] = None
+
     # azure
     azure_api_key: Optional[str] = None
     azure_base_url: Optional[str] = None
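For reviewers, a minimal usage sketch of the new provider (not part of the patch). The base URL `http://localhost:1234` and the `chatml` formatter name are illustrative assumptions, not values taken from this diff; the class, its fields, and both list methods are the ones added in `letta/providers.py` above.

```python
# Minimal sketch. Assumptions NOT in the diff: a local LMStudio server
# on http://localhost:1234 and "chatml" as the prompt formatter name.
from letta.providers import LMStudioCompletionsProvider

provider = LMStudioCompletionsProvider(
    base_url="http://localhost:1234",
    default_prompt_formatter="chatml",
)

# list_llm_models() queries LMStudio's /api/v0 model list and keeps
# entries whose "type" is "llm"; each becomes an LLMConfig.
for llm_config in provider.list_llm_models():
    print(llm_config.model, llm_config.context_window)

# list_embedding_models() keeps entries whose "type" is "embeddings".
for embedding_config in provider.list_embedding_models():
    print(embedding_config.embedding_model, embedding_config.embedding_dim)
```

On the server side the same wiring happens automatically: since `ModelSettings` is a pydantic `BaseSettings`, exporting the `LMSTUDIO_BASE_URL` environment variable should be enough for the new `if model_settings.lmstudio_base_url:` branch to register the provider at startup.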