diff --git a/biochatter/llm_connect.py b/biochatter/llm_connect.py
index 9c2bba6c..e375079f 100644
--- a/biochatter/llm_connect.py
+++ b/biochatter/llm_connect.py
@@ -1167,8 +1167,8 @@ def _primary_query(self):
         as context. Correct the response if necessary.
 
         Returns:
-            tuple: A tuple containing the response from the Anthropic API and the
-                token usage.
+            tuple: A tuple containing the response from the Anthropic API and
+                the token usage.
         """
         try:
             history = self._create_history()
@@ -1319,6 +1319,7 @@ def __init__(
         prompts: dict,
         correct: bool = False,
         split_correction: bool = False,
+        base_url: str = None,
     ):
         """
         Connect to OpenAI's GPT API and set up a conversation with the user.
@@ -1333,6 +1334,9 @@ def __init__(
             split_correction (bool): Whether to correct the model output by
                 splitting the output into sentences and correcting each
                 sentence individually.
+
+            base_url (str): Optional OpenAI base_url value to use custom
+                endpoint URL instead of default
         """
         super().__init__(
             model_name=model_name,
@@ -1340,7 +1344,7 @@ def __init__(
             correct=correct,
             split_correction=split_correction,
         )
-
+        self.base_url = base_url
         self.ca_model_name = "gpt-3.5-turbo"
         # TODO make accessible by drop-down
 
@@ -1359,6 +1363,7 @@ def set_api_key(self, api_key: str, user: str) -> bool:
         """
         client = openai.OpenAI(
             api_key=api_key,
+            base_url=self.base_url,
         )
         self.user = user
 
@@ -1368,11 +1373,13 @@ def set_api_key(self, api_key: str, user: str) -> bool:
             model_name=self.model_name,
             temperature=0,
             openai_api_key=api_key,
+            base_url=self.base_url,
         )
         self.ca_chat = ChatOpenAI(
             model_name=self.ca_model_name,
             temperature=0,
             openai_api_key=api_key,
+            base_url=self.base_url,
        )
         if user == "community":
             self.usage_stats = get_stats(user=user)