Skip to content

Commit

Permalink
[Fix] Fix GPT error with parallel calling (#40)
Browse files Browse the repository at this point in the history
* update gpt

* update

* fix

* update
  • Loading branch information
kennymckormick authored Jan 3, 2024
1 parent 9fb4b00 commit 5bbef87
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions vlmeval/api/gpt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from ..smp import *
import os, openai
from openai import OpenAI
import os, sys
from .base import BaseAPI

APIBASES = {
'OFFICIAL': "https://api.openai.com/v1/chat/completions",
'INTERNAL': "https://ai-proxy.shlab.tech/internal"
}


Expand Down Expand Up @@ -41,6 +39,7 @@ def __init__(self,
verbose: bool = True,
system_prompt: str = None,
temperature: float = 0,
timeout: int = 60,
api_base: str = 'OFFICIAL',
max_tokens: int = 1024,
img_size: int = 512,
Expand All @@ -63,15 +62,18 @@ def __init__(self,
self.vision = False
if model == 'gpt-4-vision-preview':
self.vision = True
self.timeout = timeout

assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. '
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

if api_base in APIBASES:
self.client = OpenAI(api_key=openai_key, base_url=APIBASES[api_base])
self.api_base = APIBASES[api_base]
elif api_base.startswith('http'):
self.api_base = api_base
else:
self.client = OpenAI(api_key=openai_key)

super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
self.logger.error("Unknown API Base. ")
sys.exit(-1)

# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
# content can be a string or a list of image & text
Expand Down Expand Up @@ -124,27 +126,25 @@ def generate_inner(self, inputs, **kwargs) -> str:
self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
if max_tokens <= 0:
return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '


headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'}
payload = dict(
model=self.model,
messages=input_msgs,
max_tokens=max_tokens,
n=1,
temperature=temperature,
**kwargs)
response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
ret_code = response.status_code
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
answer = self.fail_msg
try:
response = self.client.chat.completions.create(
model=self.model,
messages=input_msgs,
max_tokens=max_tokens,
n=1,
stop=None,
temperature=temperature,
**kwargs)

result = response.choices[0].message.content.strip()
return 0, result, 'API Call Succeed'
resp_struct = json.loads(response.text)
answer = resp_struct['choices'][0]['message']['content'].strip()
except:
if self.verbose:
self.logger.warning(f'OPENAI KEY {self.openai_key} FAILED !!!')
try:
self.logger.warning(response)
except:
pass
return -1, self.fail_msg, 'API Call Failed'
pass
return ret_code, answer, response

def get_token_len(self, inputs) -> int:
import tiktoken
Expand All @@ -163,7 +163,6 @@ def get_token_len(self, inputs) -> int:
res += self.get_token_len(item)
return res


class GPT4V(OpenAIWrapper):

def generate(self, image_path, prompt, dataset=None):
Expand Down

0 comments on commit 5bbef87

Please sign in to comment.