Skip to content

Commit 8a7451e

Browse files
Authored by suluyan and skyline2006
Openapi (#287)
* feat: openapi * fix pre-commit * fix bugs --------- Co-authored-by: skyline2006 <skyline2006@163.com> Co-authored-by: suluyan <suluyan.sly@alibaba-inc.com>
1 parent 5d06ba4 commit 8a7451e

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

modelscope_agent/llm/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ def chat(self,
7474
assert len(messages) > 0, 'messages list must not be empty'
7575

7676
if stream:
77-
return self._chat_stream(messages, stop=stop, **kwargs)
77+
return self._chat_stream(
78+
messages=messages, stop=stop, prompt=prompt, **kwargs)
7879
else:
79-
return self._chat_no_stream(messages, stop=stop, **kwargs)
80+
return self._chat_no_stream(
81+
messages=messages, stop=stop, prompt=prompt, **kwargs)
8082

8183
@retry(max_retries=3, delay_seconds=0.5)
8284
def chat_with_functions(self,

modelscope_agent/llm/openai.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
11
import os
2-
from typing import Dict, Iterator, List, Optional
2+
from typing import Dict, Iterator, List, Optional, Union
33

4-
import openai
54
from modelscope_agent.llm.base import BaseChatModel, register_llm
5+
from modelscope_agent.utils.retry import retry
6+
from openai import OpenAI
67

78

89
@register_llm('openai')
910
class OpenAi(BaseChatModel):
1011

11-
def __init__(self, model: str, model_server: str, **kwargs):
12+
def __init__(self,
13+
model: str,
14+
model_server: str,
15+
is_chat: bool = True,
16+
is_function_call: Optional[bool] = None,
17+
support_stream: Optional[bool] = None,
18+
**kwargs):
1219
super().__init__(model, model_server)
1320

14-
openai.api_base = kwargs.get('api_base',
15-
'https://api.openai.com/v1').strip()
16-
openai.api_key = kwargs.get(
17-
'api_key', os.getenv('OPENAI_API_KEY', default='EMPTY')).strip()
21+
api_base = kwargs.get('api_base', 'https://api.openai.com/v1').strip()
22+
api_key = kwargs.get('api_key',
23+
os.getenv('OPENAI_API_KEY',
24+
default='EMPTY')).strip()
25+
self.client = OpenAI(api_key=api_key, base_url=api_base)
26+
self.is_function_call = is_function_call
27+
self.is_chat = is_chat
28+
self.support_stream = support_stream
1829

1930
def _chat_stream(self,
2031
messages: List[Dict],
2132
stop: Optional[List[str]] = None,
2233
**kwargs) -> Iterator[str]:
23-
response = openai.ChatCompletion.create(
34+
response = self.client.chat.completions.create(
2435
model=self.model,
2536
messages=messages,
2637
stop=stop,
@@ -35,7 +46,7 @@ def _chat_no_stream(self,
3546
messages: List[Dict],
3647
stop: Optional[List[str]] = None,
3748
**kwargs) -> str:
38-
response = openai.ChatCompletion.create(
49+
response = self.client.chat.completions.create(
3950
model=self.model,
4051
messages=messages,
4152
stop=stop,
@@ -44,18 +55,70 @@ def _chat_no_stream(self,
4455
# TODO: error handling
4556
return response.choices[0].message.content
4657

58+
def support_function_calling(self):
59+
if self.is_function_call is None:
60+
return super().support_function_calling()
61+
else:
62+
return self.is_function_call
63+
64+
def support_raw_prompt(self) -> bool:
65+
if self.is_chat is None:
66+
return super().support_raw_prompt()
67+
else:
68+
# if not chat, then prompt
69+
return not self.is_chat
70+
71+
@retry(max_retries=3, delay_seconds=0.5)
72+
def chat(self,
73+
prompt: Optional[str] = None,
74+
messages: Optional[List[Dict]] = None,
75+
stop: Optional[List[str]] = None,
76+
stream: bool = False,
77+
**kwargs) -> Union[str, Iterator[str]]:
78+
if isinstance(self.support_stream, bool):
79+
stream = self.support_stream
80+
if self.support_raw_prompt():
81+
return self.chat_with_raw_prompt(
82+
prompt=prompt, stream=stream, stop=stop, **kwargs)
83+
if not messages and prompt and isinstance(prompt, str):
84+
messages = [{'role': 'user', 'content': prompt}]
85+
return super().chat(
86+
messages=messages, stop=stop, stream=stream, **kwargs)
87+
88+
def _out_generator(self, response):
89+
for chunk in response:
90+
if hasattr(chunk.choices[0], 'text'):
91+
yield chunk.choices[0].text
92+
93+
def chat_with_raw_prompt(self,
94+
prompt: str,
95+
stream: bool = True,
96+
**kwargs) -> str:
97+
max_tokens = kwargs.get('max_tokens', 2000)
98+
response = self.client.completions.create(
99+
model=self.model,
100+
prompt=prompt,
101+
stream=stream,
102+
max_tokens=max_tokens)
103+
104+
# TODO: error handling
105+
if stream:
106+
return self._out_generator(response)
107+
else:
108+
return response.choices[0].text
109+
47110
def chat_with_functions(self,
48111
messages: List[Dict],
49112
functions: Optional[List[Dict]] = None,
50113
**kwargs) -> Dict:
51114
if functions:
52-
response = openai.ChatCompletion.create(
115+
response = self.client.completions.create(
53116
model=self.model,
54117
messages=messages,
55118
functions=functions,
56119
**kwargs)
57120
else:
58-
response = openai.ChatCompletion.create(
121+
response = self.client.completions.create(
59122
model=self.model, messages=messages, **kwargs)
60123
# TODO: error handling
61124
return response.choices[0].message

0 commit comments

Comments (0)