Skip to content

Commit

Permalink
fix: openai async (#585) bump:patch
Browse files Browse the repository at this point in the history
  • Loading branch information
taprosoft authored Dec 24, 2024
1 parent 95191f5 commit 5343d0d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
34 changes: 28 additions & 6 deletions libs/kotaemon/kotaemon/llms/chats/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def openai_response(self, client, **kwargs):
"""Get the openai response"""
raise NotImplementedError

async def aopenai_response(self, client, **kwargs):
"""Get the openai response"""
raise NotImplementedError

def invoke(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> LLMInterface:
Expand All @@ -211,8 +215,10 @@ async def ainvoke(
) -> LLMInterface:
client = self.prepare_client(async_version=True)
input_messages = self.prepare_message(messages)
resp = await self.openai_response(
client, messages=input_messages, stream=False, **kwargs
resp = (
await self.aopenai_response(
client, messages=input_messages, stream=False, **kwargs
)
).dict()

return self.prepare_output(resp)
Expand Down Expand Up @@ -290,8 +296,7 @@ def prepare_client(self, async_version: bool = False):

return OpenAI(**params)

def openai_response(self, client, **kwargs):
"""Get the openai response"""
def prepare_params(self, **kwargs):
if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic")

Expand All @@ -313,8 +318,17 @@ def openai_response(self, client, **kwargs):
params = {k: v for k, v in params_.items() if v is not None}
params.update(kwargs)

return params

def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)
return client.chat.completions.create(**params)

async def aopenai_response(self, client, **kwargs):
params = self.prepare_params(**kwargs)
return await client.chat.completions.create(**params)


class AzureChatOpenAI(BaseChatOpenAI):
"""OpenAI chat model provided by Microsoft Azure"""
Expand Down Expand Up @@ -361,8 +375,7 @@ def prepare_client(self, async_version: bool = False):

return AzureOpenAI(**params)

def openai_response(self, client, **kwargs):
"""Get the openai response"""
def prepare_params(self, **kwargs):
if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic")

Expand All @@ -384,4 +397,13 @@ def openai_response(self, client, **kwargs):
params = {k: v for k, v in params_.items() if v is not None}
params.update(kwargs)

return params

def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)
return client.chat.completions.create(**params)

async def aopenai_response(self, client, **kwargs):
params = self.prepare_params(**kwargs)
return await client.chat.completions.create(**params)
2 changes: 1 addition & 1 deletion libs/ktem/ktem/pages/chat/chat_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def on_building_ui(self):
scale=20,
file_count="multiple",
placeholder=(
"Type a message, or search the @web, " "tag a file with @filename"
"Type a message, search the @web, or tag a file with @filename"
),
container=False,
show_label=False,
Expand Down

0 comments on commit 5343d0d

Please sign in to comment.