Openapi #287

Merged · 5 commits · Feb 4, 2024
Changes from all commits
6 changes: 4 additions & 2 deletions modelscope_agent/llm/base.py
@@ -74,9 +74,11 @@ def chat(self,
         assert len(messages) > 0, 'messages list must not be empty'

         if stream:
-            return self._chat_stream(messages, stop=stop, **kwargs)
+            return self._chat_stream(
+                messages=messages, stop=stop, prompt=prompt, **kwargs)
         else:
-            return self._chat_no_stream(messages, stop=stop, **kwargs)
+            return self._chat_no_stream(
+                messages=messages, stop=stop, prompt=prompt, **kwargs)

     @retry(max_retries=3, delay_seconds=0.5)
     def chat_with_functions(self,
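In effect, the base class now forwards prompt to the backend hooks as a keyword argument, so implementations need to accept it, either explicitly or through **kwargs. A minimal sketch under that assumption; the EchoModel class below is hypothetical and only illustrates the expected hook signatures:

# Hypothetical subclass, not part of this PR: illustrates the hook
# signatures that the keyword-based dispatch in base.py now expects.
from typing import Dict, Iterator, List, Optional

from modelscope_agent.llm.base import BaseChatModel


class EchoModel(BaseChatModel):

    def _chat_stream(self,
                     messages: List[Dict],
                     stop: Optional[List[str]] = None,
                     prompt: Optional[str] = None,
                     **kwargs) -> Iterator[str]:
        # prompt now arrives as an explicit keyword; echo the last message
        yield messages[-1]['content']

    def _chat_no_stream(self,
                        messages: List[Dict],
                        stop: Optional[List[str]] = None,
                        prompt: Optional[str] = None,
                        **kwargs) -> str:
        return messages[-1]['content']

# EchoModel('echo', 'local').chat(messages=[{'role': 'user', 'content': 'hi'}])
# returns 'hi' via the non-stream path ('echo' and 'local' are placeholders).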
85 changes: 74 additions & 11 deletions modelscope_agent/llm/openai.py
@@ -1,26 +1,37 @@
 import os
-from typing import Dict, Iterator, List, Optional
+from typing import Dict, Iterator, List, Optional, Union

-import openai
 from modelscope_agent.llm.base import BaseChatModel, register_llm
+from modelscope_agent.utils.retry import retry
+from openai import OpenAI


 @register_llm('openai')
 class OpenAi(BaseChatModel):

-    def __init__(self, model: str, model_server: str, **kwargs):
+    def __init__(self,
+                 model: str,
+                 model_server: str,
+                 is_chat: bool = True,
+                 is_function_call: Optional[bool] = None,
+                 support_stream: Optional[bool] = None,
+                 **kwargs):
         super().__init__(model, model_server)

-        openai.api_base = kwargs.get('api_base',
-                                     'https://api.openai.com/v1').strip()
-        openai.api_key = kwargs.get(
-            'api_key', os.getenv('OPENAI_API_KEY', default='EMPTY')).strip()
+        api_base = kwargs.get('api_base', 'https://api.openai.com/v1').strip()
+        api_key = kwargs.get('api_key',
+                             os.getenv('OPENAI_API_KEY',
+                                       default='EMPTY')).strip()
+        self.client = OpenAI(api_key=api_key, base_url=api_base)
+        self.is_function_call = is_function_call
+        self.is_chat = is_chat
+        self.support_stream = support_stream

     def _chat_stream(self,
                      messages: List[Dict],
                      stop: Optional[List[str]] = None,
                      **kwargs) -> Iterator[str]:
-        response = openai.ChatCompletion.create(
+        response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             stop=stop,
@@ -35,7 +46,7 @@ def _chat_no_stream(self,
                         messages: List[Dict],
                         stop: Optional[List[str]] = None,
                         **kwargs) -> str:
-        response = openai.ChatCompletion.create(
+        response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             stop=stop,
@@ -44,18 +55,70 @@ def _chat_no_stream(self,
         # TODO: error handling
         return response.choices[0].message.content

+    def support_function_calling(self):
+        if self.is_function_call is None:
+            return super().support_function_calling()
+        else:
+            return self.is_function_call
+
+    def support_raw_prompt(self) -> bool:
+        if self.is_chat is None:
+            return super().support_raw_prompt()
+        else:
+            # if not chat, then prompt
+            return not self.is_chat
+
+    @retry(max_retries=3, delay_seconds=0.5)
+    def chat(self,
+             prompt: Optional[str] = None,
+             messages: Optional[List[Dict]] = None,
+             stop: Optional[List[str]] = None,
+             stream: bool = False,
+             **kwargs) -> Union[str, Iterator[str]]:
+        if isinstance(self.support_stream, bool):
+            stream = self.support_stream
+        if self.support_raw_prompt():
+            return self.chat_with_raw_prompt(
+                prompt=prompt, stream=stream, stop=stop, **kwargs)
+        if not messages and prompt and isinstance(prompt, str):
+            messages = [{'role': 'user', 'content': prompt}]
+        return super().chat(
+            messages=messages, stop=stop, stream=stream, **kwargs)
+
+    def _out_generator(self, response):
+        for chunk in response:
+            if hasattr(chunk.choices[0], 'text'):
+                yield chunk.choices[0].text
+
+    def chat_with_raw_prompt(self,
+                             prompt: str,
+                             stream: bool = True,
+                             **kwargs) -> Union[str, Iterator[str]]:
+        max_tokens = kwargs.get('max_tokens', 2000)
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            stream=stream,
+            max_tokens=max_tokens)
+
+        # TODO: error handling
+        if stream:
+            return self._out_generator(response)
+        else:
+            return response.choices[0].text
+
     def chat_with_functions(self,
                             messages: List[Dict],
                             functions: Optional[List[Dict]] = None,
                             **kwargs) -> Dict:
         if functions:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 functions=functions,
                 **kwargs)
         else:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=self.model, messages=messages, **kwargs)
         # TODO: error handling
         return response.choices[0].message
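Taken together, the intended usage implied by this diff looks roughly as follows. This is a hedged sketch, not code from the PR: the model names and the local endpoint are placeholders, and exact runtime behavior depends on base-class code not shown in this diff.

# Illustrative sketch only; assumes an OPENAI_API_KEY or an
# OpenAI-compatible endpoint is available.
from modelscope_agent.llm.openai import OpenAi

# Chat-style usage (default is_chat=True): messages are sent through
# the client's chat.completions endpoint.
llm = OpenAi(model='gpt-3.5-turbo', model_server='openai')
print(llm.chat(messages=[{'role': 'user', 'content': 'Hello'}]))

# Completion-style usage: is_chat=False makes support_raw_prompt()
# return True, so chat() routes through chat_with_raw_prompt() and the
# plain completions endpoint instead.
raw_llm = OpenAi(
    model='gpt-3.5-turbo-instruct',  # placeholder completion model
    model_server='openai',
    is_chat=False,
    support_stream=False,  # forces stream=False inside chat()
    api_base='http://localhost:8000/v1')  # placeholder endpoint
print(raw_llm.chat(prompt='Once upon a time'))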