Openapi #287
Changes from 2 commits
```diff
@@ -1,8 +1,8 @@
 import os
 from typing import Dict, Iterator, List, Optional
 
-import openai
 from modelscope_agent.llm.base import BaseChatModel, register_llm
+from openai import OpenAI
 
 
 @register_llm('openai')
```
```diff
@@ -11,16 +11,17 @@ class OpenAi(BaseChatModel):
     def __init__(self, model: str, model_server: str, **kwargs):
         super().__init__(model, model_server)
 
-        openai.api_base = kwargs.get('api_base',
-                                     'https://api.openai.com/v1').strip()
-        openai.api_key = kwargs.get(
-            'api_key', os.getenv('OPENAI_API_KEY', default='EMPTY')).strip()
+        api_base = kwargs.get('api_base', 'https://api.openai.com/v1').strip()
+        api_key = kwargs.get('api_key',
+                             os.getenv('OPENAI_API_KEY',
+                                       default='EMPTY')).strip()
+        self.client = OpenAI(api_key=api_key, base_url=api_base)
 
     def _chat_stream(self,
                      messages: List[Dict],
                      stop: Optional[List[str]] = None,
                      **kwargs) -> Iterator[str]:
-        response = openai.ChatCompletion.create(
+        response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             stop=stop,
```
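For context, this hunk applies the openai-python v1 migration pattern: module-level `openai.api_base`/`openai.api_key` globals become a configured client instance, and message-based calls go through `client.chat.completions.create`. A minimal standalone sketch of that pattern (the model name and message are illustrative, not taken from this PR):

```python
from openai import OpenAI

# v1 style: configuration lives on the client, not on the openai module.
client = OpenAI(api_key='sk-...', base_url='https://api.openai.com/v1')

# chat models use client.chat.completions, the v1 replacement for the
# removed openai.ChatCompletion.create helper
completion = client.chat.completions.create(
    model='gpt-3.5-turbo',  # illustrative model name
    messages=[{'role': 'user', 'content': 'Say hello.'}],
)
print(completion.choices[0].message.content)
```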
```diff
@@ -35,7 +36,7 @@ def _chat_no_stream(self,
                         messages: List[Dict],
                         stop: Optional[List[str]] = None,
                         **kwargs) -> str:
-        response = openai.ChatCompletion.create(
+        response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             stop=stop,
```
```diff
@@ -49,13 +50,49 @@ def chat_with_functions(self,
                             functions: Optional[List[Dict]] = None,
                             **kwargs) -> Dict:
         if functions:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 functions=functions,
                 **kwargs)
         else:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=self.model, messages=messages, **kwargs)
         # TODO: error handling
         return response.choices[0].message
```
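For reference, a standalone sketch of the legacy `functions` parameter that `chat_with_functions` forwards to the API; the function schema and model name below are illustrative assumptions, not code from this repo:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model='gpt-3.5-turbo',  # illustrative model name
    messages=[{'role': 'user', 'content': 'What is the weather in Paris?'}],
    functions=[{
        # hypothetical function schema for illustration only
        'name': 'get_weather',
        'description': 'Look up current weather for a city',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    }],
)
message = response.choices[0].message
if message.function_call:
    # the model chose to call the function instead of answering directly
    print(message.function_call.name, message.function_call.arguments)
```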
```diff
+
+
+@register_llm('openapi')
+class OpenAPILocal(BaseChatModel):
```
Review comment: Does the name OpenAPILocal refer to starting an OpenAI-style model locally, or to calling a local vLLM instance? If it's the latter, should the name be changed?
```diff
+    def __init__(self, model: str, model_server: str, **kwargs):
+        super().__init__(model, model_server)
```
Review comment: Can a user run llama with this, or only openapi?

Reply: You run it through this, but you first need to bring up an OpenAI-format API locally with vLLM.
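A minimal sketch of the workflow described in that reply, assuming vLLM's OpenAI-compatible server; the model path is an illustrative placeholder:

```python
# 1) Serve a local model (e.g. a llama checkpoint) behind an
#    OpenAI-compatible API with vLLM (run in a shell):
#
#    python -m vllm.entrypoints.openai.api_server \
#        --model meta-llama/Llama-2-7b-hf --port 8000
#
# 2) Point an OpenAI client at it; OpenAPILocal does exactly this
#    with its hard-coded http://localhost:8000/v1 base URL.
from openai import OpenAI

client = OpenAI(api_key='EMPTY', base_url='http://localhost:8000/v1')
response = client.completions.create(
    model='meta-llama/Llama-2-7b-hf',  # must match the served model name
    prompt='Hello, my name is',
)
print(response.choices[0].text)
```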
```diff
+        openai_api_key = 'EMPTY'
+        openai_api_base = 'http://localhost:8000/v1'
+        self.client = OpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+
+    def _chat_stream(self, prompt: str, **kwargs) -> Iterator[str]:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            stream=True,
+        )
+        # TODO: error handling
+        for chunk in response:
+            if hasattr(chunk.choices[0], 'text'):
+                yield chunk.choices[0].text
+
+    def _chat_no_stream(self, prompt: str, **kwargs) -> str:
+        response = self.client.completions.create(
+            model=self.model,
+            prompt=prompt,
+            stream=False,
+        )
+        # TODO: error handling
+        return response.choices[0].text
+
+    def support_function_calling(self) -> bool:
+        return False
```
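For completeness, a short standalone sketch of the streaming loop that `_chat_stream` wraps, against the same assumed local server and placeholder model name:

```python
from openai import OpenAI

client = OpenAI(api_key='EMPTY', base_url='http://localhost:8000/v1')
stream = client.completions.create(
    model='meta-llama/Llama-2-7b-hf',  # illustrative; must match the server
    prompt='Once upon a time',
    stream=True,
)
for chunk in stream:
    # each streamed chunk carries an incremental text delta
    print(chunk.choices[0].text, end='', flush=True)
```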
Review comment: This LLM name duplicates the one above.

Reply: The one above is openai; this one is openapi.