Skip to content

Commit

Permalink
fix: Custom OpenAI endpoint support (#203)
Browse files Browse the repository at this point in the history
* Custom OpenAI endpoint support

* Custom OpenAI endpoint support

* docstring

---------

Co-authored-by: slobentanzer <sebastian.lobentanzer@gmail.com>
  • Loading branch information
winternewt and slobentanzer authored Sep 19, 2024
1 parent f52a2e0 commit d3df12f
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions biochatter/llm_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,8 +1167,8 @@ def _primary_query(self):
as context. Correct the response if necessary.
Returns:
tuple: A tuple containing the response from the Anthropic API and the
token usage.
tuple: A tuple containing the response from the Anthropic API and
the token usage.
"""
try:
history = self._create_history()
Expand Down Expand Up @@ -1319,6 +1319,7 @@ def __init__(
prompts: dict,
correct: bool = False,
split_correction: bool = False,
base_url: str = None,
):
"""
Connect to OpenAI's GPT API and set up a conversation with the user.
Expand All @@ -1333,14 +1334,17 @@ def __init__(
split_correction (bool): Whether to correct the model output by
splitting the output into sentences and correcting each
sentence individually.
base_url (str): Optional OpenAI base_url value to use custom
endpoint URL instead of default
"""
super().__init__(
model_name=model_name,
prompts=prompts,
correct=correct,
split_correction=split_correction,
)

self.base_url = base_url
self.ca_model_name = "gpt-3.5-turbo"
# TODO make accessible by drop-down

Expand All @@ -1359,6 +1363,7 @@ def set_api_key(self, api_key: str, user: str) -> bool:
"""
client = openai.OpenAI(
api_key=api_key,
base_url=self.base_url,
)
self.user = user

Expand All @@ -1368,11 +1373,13 @@ def set_api_key(self, api_key: str, user: str) -> bool:
model_name=self.model_name,
temperature=0,
openai_api_key=api_key,
base_url=self.base_url,
)
self.ca_chat = ChatOpenAI(
model_name=self.ca_model_name,
temperature=0,
openai_api_key=api_key,
base_url=self.base_url,
)
if user == "community":
self.usage_stats = get_stats(user=user)
Expand Down

0 comments on commit d3df12f

Please sign in to comment.