11import os
2- from typing import Dict , Iterator , List , Optional
2+ from typing import Dict , Iterator , List , Optional , Union
33
4- import openai
54from modelscope_agent .llm .base import BaseChatModel , register_llm
5+ from modelscope_agent .utils .retry import retry
6+ from openai import OpenAI
67
78
89@register_llm ('openai' )
910class OpenAi (BaseChatModel ):
1011
11- def __init__ (self , model : str , model_server : str , ** kwargs ):
12+ def __init__ (self ,
13+ model : str ,
14+ model_server : str ,
15+ is_chat : bool = True ,
16+ is_function_call : Optional [bool ] = None ,
17+ support_stream : Optional [bool ] = None ,
18+ ** kwargs ):
1219 super ().__init__ (model , model_server )
1320
14- openai .api_base = kwargs .get ('api_base' ,
15- 'https://api.openai.com/v1' ).strip ()
16- openai .api_key = kwargs .get (
17- 'api_key' , os .getenv ('OPENAI_API_KEY' , default = 'EMPTY' )).strip ()
21+ api_base = kwargs .get ('api_base' , 'https://api.openai.com/v1' ).strip ()
22+ api_key = kwargs .get ('api_key' ,
23+ os .getenv ('OPENAI_API_KEY' ,
24+ default = 'EMPTY' )).strip ()
25+ self .client = OpenAI (api_key = api_key , base_url = api_base )
26+ self .is_function_call = is_function_call
27+ self .is_chat = is_chat
28+ self .support_stream = support_stream
1829
1930 def _chat_stream (self ,
2031 messages : List [Dict ],
2132 stop : Optional [List [str ]] = None ,
2233 ** kwargs ) -> Iterator [str ]:
23- response = openai . ChatCompletion .create (
34+ response = self . client . chat . completions .create (
2435 model = self .model ,
2536 messages = messages ,
2637 stop = stop ,
@@ -35,7 +46,7 @@ def _chat_no_stream(self,
3546 messages : List [Dict ],
3647 stop : Optional [List [str ]] = None ,
3748 ** kwargs ) -> str :
38- response = openai . ChatCompletion .create (
49+ response = self . client . chat . completions .create (
3950 model = self .model ,
4051 messages = messages ,
4152 stop = stop ,
@@ -44,18 +55,70 @@ def _chat_no_stream(self,
4455 # TODO: error handling
4556 return response .choices [0 ].message .content
4657
58+ def support_function_calling (self ):
59+ if self .is_function_call is None :
60+ return super ().support_function_calling ()
61+ else :
62+ return self .is_function_call
63+
64+ def support_raw_prompt (self ) -> bool :
65+ if self .is_chat is None :
66+ return super ().support_raw_prompt ()
67+ else :
68+ # if not chat, then prompt
69+ return not self .is_chat
70+
71+ @retry (max_retries = 3 , delay_seconds = 0.5 )
72+ def chat (self ,
73+ prompt : Optional [str ] = None ,
74+ messages : Optional [List [Dict ]] = None ,
75+ stop : Optional [List [str ]] = None ,
76+ stream : bool = False ,
77+ ** kwargs ) -> Union [str , Iterator [str ]]:
78+ if isinstance (self .support_stream , bool ):
79+ stream = self .support_stream
80+ if self .support_raw_prompt ():
81+ return self .chat_with_raw_prompt (
82+ prompt = prompt , stream = stream , stop = stop , ** kwargs )
83+ if not messages and prompt and isinstance (prompt , str ):
84+ messages = [{'role' : 'user' , 'content' : prompt }]
85+ return super ().chat (
86+ messages = messages , stop = stop , stream = stream , ** kwargs )
87+
88+ def _out_generator (self , response ):
89+ for chunk in response :
90+ if hasattr (chunk .choices [0 ], 'text' ):
91+ yield chunk .choices [0 ].text
92+
93+ def chat_with_raw_prompt (self ,
94+ prompt : str ,
95+ stream : bool = True ,
96+ ** kwargs ) -> str :
97+ max_tokens = kwargs .get ('max_tokens' , 2000 )
98+ response = self .client .completions .create (
99+ model = self .model ,
100+ prompt = prompt ,
101+ stream = stream ,
102+ max_tokens = max_tokens )
103+
104+ # TODO: error handling
105+ if stream :
106+ return self ._out_generator (response )
107+ else :
108+ return response .choices [0 ].text
109+
47110 def chat_with_functions (self ,
48111 messages : List [Dict ],
49112 functions : Optional [List [Dict ]] = None ,
50113 ** kwargs ) -> Dict :
51114 if functions :
52- response = openai . ChatCompletion .create (
115+ response = self . client . completions .create (
53116 model = self .model ,
54117 messages = messages ,
55118 functions = functions ,
56119 ** kwargs )
57120 else :
58- response = openai . ChatCompletion .create (
121+ response = self . client . completions .create (
59122 model = self .model , messages = messages , ** kwargs )
60123 # TODO: error handling
61124 return response .choices [0 ].message
0 commit comments