feat: add LMStudioProvider #1876

Draft · wants to merge 3 commits into base: main
76 changes: 76 additions & 0 deletions letta/providers.py
@@ -1,3 +1,4 @@
import warnings
from typing import List, Optional

from pydantic import BaseModel, Field, model_validator
@@ -10,6 +11,7 @@
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.utils import smart_urljoin


class Provider(BaseModel):
@@ -512,4 +514,78 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:


class CohereProvider(OpenAIProvider):
    # TODO(matt)
    pass


class LMStudioCompletionsProvider(Provider):
    """LMStudio server via the /completions API (not /chat/completions).

    See: https://lmstudio.ai/docs/basics/server#openai-like-api-endpoints
    """

    name: str = "lmstudio"
    base_url: str = Field(..., description="Base URL for the LMStudio API.")
    default_prompt_formatter: str = Field(
        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
    )

    def list_llm_models(self) -> List[LLMConfig]:
        # Query LMStudio's /api/v0 REST API, which exposes model metadata (type, max_context_length)
        from letta.llm_api.openai import openai_get_model_list

        model_path = smart_urljoin(self.base_url, "api/v0/")
        response = openai_get_model_list(model_path, api_key=None)

        if "data" not in response:
            warnings.warn(f"LMStudio returned an unexpected response: {response}")
            return []

        configs = []
        for model in response["data"]:
            # LMStudio's backend tags each model with a type field, which can be "llm" or "embeddings"
            if "type" in model and model["type"] == "llm":
                if "max_context_length" not in model:
                    warnings.warn(f"LMStudio model is missing max_context_length field: {model}")
                    continue
                if "id" not in model:
                    warnings.warn(f"LMStudio model is missing id field: {model}")
                    continue
                configs.append(
                    LLMConfig(
                        model=model["id"],
                        model_endpoint_type="lmstudio",
                        model_endpoint=self.base_url,
                        model_wrapper=self.default_prompt_formatter,
                        context_window=model["max_context_length"],
                    )
                )
        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        from letta.llm_api.openai import openai_get_model_list

        model_path = smart_urljoin(self.base_url, "api/v0/")
        response = openai_get_model_list(model_path, api_key=None)

        if "data" not in response:
            warnings.warn(f"LMStudio returned an unexpected response: {response}")
            return []

        configs = []
        for model in response["data"]:
            # LMStudio's backend tags each model with a type field, which can be "llm" or "embeddings"
            if "type" in model and model["type"] == "embeddings":
                if "max_context_length" not in model:
                    warnings.warn(f"LMStudio model is missing max_context_length field: {model}")
                    continue
                if "id" not in model:
                    warnings.warn(f"LMStudio model is missing id field: {model}")
                    continue
                configs.append(
                    EmbeddingConfig(
                        embedding_model=model["id"],
                        embedding_endpoint_type="openai",
                        embedding_endpoint=self.base_url,
                        # NOTE: no embedding dimension is read from the response, so
                        # max_context_length is used as a stand-in for embedding_dim
                        embedding_dim=model["max_context_length"],
                        embedding_chunk_size=300,
                    )
                )

        return configs
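
Usage note (not part of the diff): a minimal sketch of how the new provider could be exercised against a local LMStudio server. The port and the "chatml" formatter name are assumptions for illustration; the field names and methods come from the class above.

# Illustrative sketch only: assumes LMStudio is serving on its default port 1234
# and that "chatml" is a valid prompt formatter name in this install (hypothetical).
from letta.providers import LMStudioCompletionsProvider

provider = LMStudioCompletionsProvider(
    base_url="http://localhost:1234",
    default_prompt_formatter="chatml",
)

# Queries <base_url>/api/v0/ and keeps entries whose "type" is "llm"
for llm_config in provider.list_llm_models():
    print(llm_config.model, llm_config.context_window)

# Same endpoint, but keeps entries whose "type" is "embeddings"
for embedding_config in provider.list_embedding_models():
    print(embedding_config.embedding_model, embedding_config.embedding_dim)
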
8 changes: 8 additions & 0 deletions letta/server/server.py
@@ -45,6 +45,7 @@
    GoogleAIProvider,
    GroqProvider,
    LettaProvider,
    LMStudioCompletionsProvider,
    OllamaProvider,
    OpenAIProvider,
    Provider,
@@ -301,6 +302,13 @@ def __init__(
            )
        if model_settings.groq_api_key:
            self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
        if model_settings.lmstudio_base_url:
            self._enabled_providers.append(
                LMStudioCompletionsProvider(
                    base_url=model_settings.lmstudio_base_url,
                    default_prompt_formatter=model_settings.default_prompt_formatter,
                )
            )
        if model_settings.vllm_api_base:
            # vLLM exposes both a /chat/completions and a /completions endpoint
            self._enabled_providers.append(
3 changes: 3 additions & 0 deletions letta/settings.py
@@ -32,6 +32,9 @@ class ModelSettings(BaseSettings):
    # ollama
    ollama_base_url: Optional[str] = None

    # lmstudio
    lmstudio_base_url: Optional[str] = None

    # azure
    azure_api_key: Optional[str] = None
    azure_base_url: Optional[str] = None
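
Configuration note (not part of the diff): ModelSettings is a pydantic BaseSettings class, so the new field is expected to be filled from the environment before the server starts; once set, the server.py change above registers an LMStudioCompletionsProvider. A rough sketch, assuming the field maps to an unprefixed LMSTUDIO_BASE_URL environment variable (check the actual ModelSettings env configuration):

# Illustrative sketch only: the env-var name assumes pydantic's default
# field-name-to-variable mapping with no prefix.
import os

os.environ["LMSTUDIO_BASE_URL"] = "http://localhost:1234"  # assumed default LMStudio port

from letta.settings import ModelSettings

settings = ModelSettings()
print(settings.lmstudio_base_url)  # -> "http://localhost:1234" if the mapping above holds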