From 7dee93fb3e6943fdc20d1fdde099edf39e464e83 Mon Sep 17 00:00:00 2001
From: iyuge2
Date: Thu, 16 May 2024 17:12:11 +0800
Subject: [PATCH] [Model] Feature add glmv (#201)

* add glm vision api

* add glm vision api with pre-commit
---
 vlmeval/api/__init__.py   |   3 +-
 vlmeval/api/glm_vision.py | 101 ++++++++++++++++++++++++++++++++++++++
 vlmeval/config.py         |   2 +
 3 files changed, 105 insertions(+), 1 deletion(-)
 create mode 100644 vlmeval/api/glm_vision.py

diff --git a/vlmeval/api/__init__.py b/vlmeval/api/__init__.py
index 82263b33b..5e513d35b 100644
--- a/vlmeval/api/__init__.py
+++ b/vlmeval/api/__init__.py
@@ -7,9 +7,10 @@
 from .stepai import Step1V_INT
 from .claude import Claude_Wrapper, Claude3V
 from .reka import Reka
+from .glm_vision import GLMVisionAPI
 
 __all__ = [
     'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper', 'GPT4V',
     'GPT4V_Internal', 'GeminiProVision', 'QwenVLWrapper', 'QwenVLAPI',
-    'QwenAPI', 'Claude3V', 'Claude_Wrapper', 'Reka', 'Step1V_INT'
+    'QwenAPI', 'Claude3V', 'Claude_Wrapper', 'Reka', 'Step1V_INT', 'GLMVisionAPI'
 ]
diff --git a/vlmeval/api/glm_vision.py b/vlmeval/api/glm_vision.py
new file mode 100644
index 000000000..8d5828f4d
--- /dev/null
+++ b/vlmeval/api/glm_vision.py
@@ -0,0 +1,101 @@
+from vlmeval.smp import *
+from vlmeval.api.base import BaseAPI
+from vlmeval.utils.dataset import DATASET_TYPE
+from vlmeval.smp.vlm import encode_image_file_to_base64
+
+
+class GLMVisionWrapper(BaseAPI):
+
+    is_api: bool = True
+
+    def __init__(self,
+                 model: str,
+                 retry: int = 5,
+                 wait: int = 5,
+                 key: str = None,
+                 verbose: bool = True,
+                 system_prompt: str = None,
+                 max_tokens: int = 1024,
+                 proxy: str = None,
+                 **kwargs):
+
+        self.model = model
+        self.fail_msg = 'Failed to obtain answer via API. '
+        self.default_params = {
+            'top_p': 0.6,
+            'top_k': 2,
+            'temperature': 0.8,
+            'repetition_penalty': 1.1,
+            'best_of': 1,
+            'do_sample': True,
+            'stream': False,
+            'max_tokens': max_tokens
+        }
+        if key is None:
+            key = os.environ.get('GLMV_API_KEY', None)
+        assert key is not None, (
+            'Please set the API Key (obtain it here: '
+            'https://open.bigmodel.cn/dev/howuse/introduction)'
+        )
+        self.key = key
+        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
+
+    def image_to_base64(self, image_path):
+        import base64
+        with open(image_path, 'rb') as image_file:
+            encoded_string = base64.b64encode(image_file.read())
+        return encoded_string.decode('utf-8')
+
+    def build_msgs(self, msgs_raw, system_prompt=None, dataset=None):
+        msgs = cp.deepcopy(msgs_raw)
+        content = []
+        text = ''
+        for i, msg in enumerate(msgs):
+            if msg['type'] == 'text':
+                text += msg['value']
+            elif msg['type'] == 'image':
+                content.append(dict(type='image_url', image_url=dict(url=encode_image_file_to_base64(msg['value']))))
+        if DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']:
+            text += '\nShort Answer.'
+        content.append(dict(type='text', text=text))
+        ret = [dict(role='user', content=content)]
+        return ret
+
+    def generate_inner(self, inputs, **kwargs) -> str:
+        assert isinstance(inputs, str) or isinstance(inputs, list)
+        inputs = [inputs] if isinstance(inputs, str) else inputs
+        messages = self.build_msgs(msgs_raw=inputs, dataset=kwargs['dataset'])
+
+        url = 'https://api.chatglm.cn/v1/chat/completions'
+        headers = {
+            'Content-Type': 'application/json',
+            'Request-Id': 'remote-test',
+            'Authorization': f'Bearer {self.key}'
+        }
+        payload = {
+            'model': self.model,
+            'messages': messages,
+            **self.default_params
+        }
+        response = requests.post(url, headers=headers, data=json.dumps(payload), verify=False)
+        output = []
+        try:
+            assert response.status_code == 200
+            for line in response.iter_lines():
+                data = json.loads(line.decode('utf-8').lstrip('data: '))
+                output.append(data['choices'][0]['message']['content'])
+            answer = ''.join(output).replace('', '')
+            if self.verbose:
+                self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
+            return 0, answer, 'Succeeded! '
+        except Exception as err:
+            if self.verbose:
+                self.logger.error(err)
+                self.logger.error(f'The input messages are {inputs}.')
+            return -1, self.fail_msg, ''
+
+
+class GLMVisionAPI(GLMVisionWrapper):
+
+    def generate(self, message, dataset=None):
+        return super(GLMVisionAPI, self).generate(message, dataset=dataset)
diff --git a/vlmeval/config.py b/vlmeval/config.py
index acaf100ad..9d3c427c9 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -55,6 +55,8 @@
     'Claude3V_Opus': partial(Claude3V, model='claude-3-opus-20240229', temperature=0, retry=10),
     'Claude3V_Sonnet': partial(Claude3V, model='claude-3-sonnet-20240229', temperature=0, retry=10),
     'Claude3V_Haiku': partial(Claude3V, model='claude-3-haiku-20240307', temperature=0, retry=10),
+    # GLM4V
+    'GLM4V': partial(GLMVisionAPI, model='glm4v-biz-eval', temperature=0, retry=10),
 }
 
 xtuner_series = {