feat: allow passing base_url in openai section of config.toml #328

Merged · 1 commit · Dec 4, 2024
1 change: 1 addition & 0 deletions config.toml
@@ -13,6 +13,7 @@ region_name = "us-east-1"
 simple_model = "gpt-4o-mini"
 complex_model = "gpt-4o-2024-08-06"
 embeddings_model = "text-embedding-ada-002"
+base_url = "https://api.openai.com/v1/" # for openai compatible apis
 
 [ollama]
 simple_model = "llama3.2"
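For context, the new `base_url` key lets the OpenAI client target any OpenAI-compatible server instead of api.openai.com. A minimal sketch of what this enables, assuming a local Ollama instance serving its OpenAI-compatible endpoint (the URL, model name, and dummy API key below are illustrative, not part of this PR):

```python
import os

from langchain_openai import ChatOpenAI

# Many OpenAI-compatible servers ignore the key, but the client requires one;
# "ollama" here is a placeholder, not a real credential.
os.environ.setdefault("OPENAI_API_KEY", "ollama")

llm = ChatOpenAI(
    model="llama3.2",                       # assumed local model name
    base_url="http://localhost:11434/v1/",  # Ollama's OpenAI-compatible endpoint
)
print(llm.invoke("Say hello in one word.").content)
```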
29 changes: 29 additions & 0 deletions src/rai/rai/utils/configurator.py
@@ -139,10 +139,39 @@ def on_vendor_change():
             value=st.session_state["config"]["openai"]["embeddings_model"],
             key="embeddings_model",
         )
+
+        def on_openai_compatible_api_change():
+            st.session_state.use_openai_compatible_api = (
+                st.session_state.openai_compatible_api_checkbox
+            )
+
+        if "use_openai_compatible_api" not in st.session_state:
+            st.session_state.use_openai_compatible_api = False
+
+        use_openai_compatible_api = st.checkbox(
+            "Use OpenAI compatible API",
+            value=st.session_state.use_openai_compatible_api,
+            key="openai_compatible_api_checkbox",
+            on_change=on_openai_compatible_api_change,
+        )
+        st.session_state.use_openai_compatible_api = use_openai_compatible_api
+
+        if use_openai_compatible_api:
+            st.info(
+                "Used for OpenAI compatible endpoints, e.g. Ollama, vLLM... Make sure to specify `OPENAI_API_KEY` environment variable based on vendor's specification."
+            )
+            openai_api_base_url = st.text_input(
+                "OpenAI API base URL",
+                value=st.session_state["config"]["openai"]["base_url"],
+                key="openai_api_base_url",
+            )
+        else:
+            openai_api_base_url = st.session_state["config"]["openai"]["base_url"]
         st.session_state.config["openai"] = {
             "simple_model": simple_model,
             "complex_model": complex_model,
             "embeddings_model": embeddings_model,
+            "base_url": openai_api_base_url,
         }
 
     elif vendor == "aws":
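The checkbox above uses Streamlit's callback-plus-session-state idiom so the toggle survives page reruns. A stripped-down, runnable sketch of the same pattern, with illustrative widget keys and labels that are not taken from this PR:

```python
# Run with: streamlit run demo.py
import streamlit as st

def on_toggle():
    # Persist the widget value under a stable key so it survives reruns.
    st.session_state.use_compatible_api = st.session_state.compatible_api_checkbox

if "use_compatible_api" not in st.session_state:
    st.session_state.use_compatible_api = False

enabled = st.checkbox(
    "Use OpenAI compatible API",
    value=st.session_state.use_compatible_api,
    key="compatible_api_checkbox",
    on_change=on_toggle,
)

if enabled:
    base_url = st.text_input("Base URL", value="https://api.openai.com/v1/")
    st.write(f"Requests will go to {base_url}")
```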
24 changes: 18 additions & 6 deletions src/rai/rai/utils/model_initialization.py
@@ -15,7 +15,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import List, Literal
+from typing import List, Literal, Optional, cast
 
 import coloredlogs
 import tomli
@@ -50,6 +50,11 @@ class OllamaConfig(ModelConfig):
     base_url: str
 
 
+@dataclass
+class OpenAIConfig(ModelConfig):
+    base_url: str
+
+
 @dataclass
 class LangfuseConfig:
     use_langfuse: bool
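`OpenAIConfig` mirrors `OllamaConfig`: it extends the shared `ModelConfig` with a vendor-specific `base_url` field, so the `**`-unpacking in `load_config` keeps working as long as the TOML keys match the field names. A self-contained sketch of that mapping, where the `ModelConfig` fields are inferred from config.toml rather than copied from this diff:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Field names inferred from the [openai] section of config.toml.
    simple_model: str
    complex_model: str
    embeddings_model: str

@dataclass
class OpenAIConfig(ModelConfig):
    base_url: str

# Mimics OpenAIConfig(**config_dict["openai"]): TOML keys become field values.
openai_section = {
    "simple_model": "gpt-4o-mini",
    "complex_model": "gpt-4o-2024-08-06",
    "embeddings_model": "text-embedding-ada-002",
    "base_url": "https://api.openai.com/v1/",
}
cfg = OpenAIConfig(**openai_section)
print(cfg.base_url)  # https://api.openai.com/v1/
```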
@@ -72,7 +77,7 @@ class TracingConfig:
 class RAIConfig:
     vendor: VendorConfig
     aws: AWSConfig
-    openai: ModelConfig
+    openai: OpenAIConfig
     ollama: OllamaConfig
     tracing: TracingConfig
 
@@ -83,7 +88,7 @@ def load_config() -> RAIConfig:
     return RAIConfig(
         vendor=VendorConfig(**config_dict["vendor"]),
         aws=AWSConfig(**config_dict["aws"]),
-        openai=ModelConfig(**config_dict["openai"]),
+        openai=OpenAIConfig(**config_dict["openai"]),
         ollama=OllamaConfig(**config_dict["ollama"]),
         tracing=TracingConfig(
             project=config_dict["tracing"]["project"],
@@ -94,7 +99,9 @@ def get_llm_model(
 
 
 def get_llm_model(
-    model_type: Literal["simple_model", "complex_model"], vendor: str = None, **kwargs
+    model_type: Literal["simple_model", "complex_model"],
+    vendor: Optional[str] = None,
+    **kwargs,
 ):
     config = load_config()
     if vendor is None:
@@ -106,14 +113,18 @@ def get_llm_model(
     model_config = getattr(config, vendor)
 
     model = getattr(model_config, model_type)
-    logger.info(f"Using LLM model: {vendor}-{model}")
+    logger.info(f"Initializing {model_type}: Vendor: {vendor}, Model: {model}")
     if vendor == "openai":
         from langchain_openai import ChatOpenAI
 
-        return ChatOpenAI(model=model, **kwargs)
+        model_config = cast(OpenAIConfig, model_config)
+
+        return ChatOpenAI(model=model, base_url=model_config.base_url, **kwargs)
     elif vendor == "aws":
         from langchain_aws import ChatBedrock
 
+        model_config = cast(AWSConfig, model_config)
+
         return ChatBedrock(
             model_id=model,
             region_name=model_config.region_name,
@@ -122,6 +133,7 @@ elif vendor == "ollama":
     elif vendor == "ollama":
         from langchain_ollama import ChatOllama
 
+        model_config = cast(OllamaConfig, model_config)
         return ChatOllama(model=model, base_url=model_config.base_url, **kwargs)
     else:
         raise ValueError(f"Unknown LLM vendor: {vendor}")
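With the config wired through, callers don't change: `get_llm_model` falls back to the configured vendor when none is given (the exact fallback lines are truncated in this view) and now threads `base_url` into `ChatOpenAI`. A hedged usage sketch, assuming the package is installed from src/rai so the module resolves as `rai.utils.model_initialization`; the kwargs are illustrative:

```python
from rai.utils.model_initialization import get_llm_model

# Vendor appears to fall back to config.toml's [vendor] section when omitted;
# pass it explicitly to be sure. Extra kwargs are forwarded to the chat class.
llm = get_llm_model("simple_model", vendor="openai", temperature=0)

# With base_url pointing at an OpenAI-compatible server in config.toml,
# this call goes to that server instead of api.openai.com.
print(llm.invoke("ping").content)
```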