[Feature] Add run.py and simplify the evaluation. #12

Merged · 10 commits · Dec 14, 2023
32 changes: 32 additions & 0 deletions Custom_Benchmark_and_Model.md
@@ -0,0 +1,32 @@
# 🛠️ How to implement a new Benchmark / VLM in VLMEvalKit?

## Implement a new benchmark

Currently, we organize each benchmark as a single TSV file. During inference, the data file is automatically downloaded to `$LMUData` (the default path is `$HOME/LMUData` if the variable is not set explicitly). All existing benchmark TSV files are handled by the `TSVDataset` class implemented in `vlmeval/utils/data_util.py`.
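For reference, once such a TSV file is in place it can be loaded through `TSVDataset`. A minimal sketch follows; the dataset name is one of the entries from Table 1 below, and treating `.data` as a pandas DataFrame is an assumption based on how `run.py` uses it.

```python
from vlmeval.utils import TSVDataset

# Loads $LMUData/MMBench_DEV_EN.tsv, downloading it automatically if missing
dataset = TSVDataset('MMBench_DEV_EN')
data = dataset.data  # the benchmark table, with the fields listed in Table 1

print(len(data))                  # number of samples
print(data.iloc[0]['question'])   # the first question
```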

| Dataset Name \ Fields | index | image | image_path | question | hint | A | B | C | D | answer | category | l2-category | split |
| --------------------- | ----- | ----- | ---------- | -------- | ---- | ---- | ---- | ---- | ---- | ------ | -------- | ----------- | ----- |
| MMBench_DEV_CN/EN | √ | √ | | √ | √ | √ | √ | √ | √ | √ | √ | √ | √ |
| MMBench_TEST_CN/EN | √ | √ | | √ | √ | √ | √ | √ | √ | | √ | √ | √ |
| CCBench | √ | √ | | √ | | √ | √ | √ | √ | √ | √ | | |
| SEEDBench_IMG | √ | √ | | √ | | √ | √ | √ | √ | √ | √ | | |
| MME | √ | √ | | √ | | | | | | √ | √ | | |
| CORE_MM | √ | √ | √ | √ | | | | | | | √ | | |
| MMVet | √ | √ | | √ | | | | | | √ | √ | | |

<div align="center"><b>Table 1. TSV fields of supported datasets.</b></div>

**Intro to some fields:**

- **index:** an integer, unique for each line of the TSV file
- **image:** the base64-encoded image. You can use the APIs implemented in `vlmeval/smp.py` for encoding and decoding (a usage sketch follows this list):
  - Encoding: `encode_image_to_base64` (takes a PIL Image) / `encode_image_file_to_base64` (takes an image file path)
  - Decoding: `decode_base64_to_image` (returns a PIL Image) / `decode_base64_to_image_file` (writes to an image file path)
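A usage sketch for these helpers; the exact signatures are inferred from the function names and the descriptions above, so treat them as assumptions:

```python
from PIL import Image
from vlmeval.smp import (
    encode_image_to_base64, encode_image_file_to_base64,
    decode_base64_to_image, decode_base64_to_image_file,
)

# Encoding: from an in-memory PIL Image, or directly from a file on disk
img = Image.open('example.jpg')                       # hypothetical local image
b64_from_pil = encode_image_to_base64(img)
b64_from_file = encode_image_file_to_base64('example.jpg')

# Decoding: back to a PIL Image, or straight to an image file
restored = decode_base64_to_image(b64_from_pil)
decode_base64_to_image_file(b64_from_file, 'restored.jpg')
```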

Besides, your dataset class **should implement the method `build_prompt(self, line, dataset=None)`**. Given `line`, which is either a line number or a single row of the TSV file, the method returns a dictionary `dict(image=image_path, text=prompt)` containing the image path and the prompt that will be fed to the VLMs.
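A minimal sketch of such a method is shown below. The class name, the image-dumping path scheme, and the reuse of `decode_base64_to_image_file` are illustrative assumptions; only the signature and the returned dictionary follow the convention above.

```python
from vlmeval.utils import TSVDataset
from vlmeval.smp import decode_base64_to_image_file

class MyBenchmark(TSVDataset):  # hypothetical: extend or mirror the existing TSVDataset

    def build_prompt(self, line, dataset=None):
        # `line` may be an integer row number or an already-selected TSV row
        if isinstance(line, int):
            line = self.data.iloc[line]

        # Materialize the base64-encoded image as a local file for the VLM
        image_path = f'/tmp/{line["index"]}.jpg'      # illustrative path scheme
        decode_base64_to_image_file(line['image'], image_path)

        # Combine the hint (if present) with the question to form the prompt
        prompt = line['question']
        if 'hint' in line and isinstance(line['hint'], str):
            prompt = line['hint'] + '\n' + prompt

        return dict(image=image_path, text=prompt)
```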

## Implement a new model

All existing models are implemented in `vlmeval/vlm`. For a minimal model, your model class **should implement the method** `generate(image_path, prompt, dataset=None)`. In this function, you feed the image and the prompt to your VLM and return the VLM prediction as a string. The optional argument `dataset` can be used as a flag that lets the model switch among different inference strategies.

Besides, your model can support custom prompt building by implementing an optional method `build_prompt(line, dataset=None)`. Here `line` is a dictionary containing the necessary information of a data sample, while `dataset` can be used as a flag to switch among different prompt-building strategies (a sketch combining both methods follows).
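Putting the two requirements together, a minimal model wrapper might look like the following sketch. Everything about model loading and inference is a placeholder; only the `generate` signature, the string return value, and the `dataset` flag follow the description above, and the `build_prompt` return format is assumed to mirror the dataset-side convention.

```python
class MyVLM:
    """Skeleton VLM wrapper; replace the placeholder bodies with real loading / inference."""

    def __init__(self, model_path='my-org/my-vlm', **kwargs):   # hypothetical checkpoint name
        self.model_path = model_path
        self.kwargs = kwargs
        # e.g. load your checkpoint / processor here

    def generate(self, image_path, prompt, dataset=None):
        # `dataset` can act as a flag to switch inference strategies per benchmark
        if dataset is not None and 'MMBench' in dataset:
            prompt += '\nAnswer with the option letter only.'
        # Run your VLM on (image_path, prompt) here; the return value must be a string
        return f'[placeholder answer for {image_path}]'

    def build_prompt(self, line, dataset=None):
        # Optional: assemble a custom prompt from the sample dict
        # (keys and return structure are assumptions mirroring the dataset-side convention)
        return dict(image=line['image_path'], text=line['question'])
```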
120 changes: 34 additions & 86 deletions README.md

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions cover.sh

This file was deleted.

100 changes: 100 additions & 0 deletions run.py
@@ -0,0 +1,100 @@
import torch
import torch.distributed as dist
from vlmeval.smp import *
from vlmeval.eval import MME_eval, MMVet_eval, multiple_choice_eval, MME_rating, MME_postproc
from vlmeval.infer import infer_data, prefetch_acc
from vlmeval.utils import TSVDataset
from vlmeval.config import supported_VLM

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
    parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()
    return args

def main():
    logger = get_logger('RUN')

    args = parse_args()
    assert len(args.data), '--data should be a list of data files'

    rank, world_size = get_rank_and_world_size()
    if world_size > 1:
        torch.cuda.set_device(rank)
        dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400))

    for _, model_name in enumerate(args.model):
        model = None
        os.makedirs(model_name, exist_ok=True)
        pred_root = model_name

        for i, dataset_name in enumerate(args.data):
            tmpl = f'{pred_root}/' + '{}' + f'{world_size}_{dataset_name}.pkl'
            out_file = tmpl.format(rank)
            result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'

            if model is None:
                model = model_name  # which is only a name

            # CHECKER
            if dataset_name == 'CORE_MM':
                MULTI_IMG = getattr(supported_VLM[model_name].func, 'MULTI_IMG', False)
                if not MULTI_IMG:
                    logger.error(f'Model {model_name} does not support the `multi_generate` interface, which is required for testing CORE_MM, skip it. ')
                    continue
                if args.mode == 'all':
                    logger.error(f'Dataset {dataset_name} does not support `evaluation` now, will skip the evaluation. ')

            if not osp.exists(result_file):
                model = infer_data(model, dataset_name=dataset_name, out_file=out_file, verbose=args.verbose)
                if world_size > 1:
                    dist.barrier()

                if rank == 0:
                    # Merge the per-rank prediction pickles into a single result file
                    data_all = {}
                    for i in range(world_size):
                        data_all.update(load(tmpl.format(i)))

                    data = TSVDataset(dataset_name).data
                    assert len(data_all) == len(data)
                    data['prediction'] = [data_all[x] for x in data['index']]
                    data.pop('image')

                    if dataset_name == 'MME':
                        data = MME_postproc(data)

                    dump(data, result_file)
                    for i in range(world_size):
                        os.remove(tmpl.format(i))

            if rank == 0:
                # Report heuristic 'prefetch' scores that do not require a judge model
                time.sleep(3)
                res = None
                if dataset_name == 'MME':
                    res = MME_rating(result_file)
                elif dataset_name not in ['CORE_MM', 'MMVet']:
                    res = prefetch_acc(result_file)
                else:
                    logger.warning(f'{dataset_name} is not handled by prefetch score calculator')
                if res is not None:
                    logger.info(f'{model_name} prefetching: ')
                    logger.info(res)
                    dump(res, result_file.replace('.xlsx', '_prefetch.xlsx'))

            if rank == 0 and args.mode == 'all':
                # Full evaluation; these branches may call a GPT judge
                if listinstr(['MMBench', 'CCBench', 'SEEDBench_IMG'], dataset_name):
                    multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose)
                elif dataset_name == 'MME':
                    MME_eval(result_file, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose)
                elif dataset_name == 'MMVet':
                    MMVet_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
                else:
                    logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')

if __name__ == '__main__':
    main()
4 changes: 4 additions & 0 deletions scripts/cover.sh
@@ -0,0 +1,4 @@
#!/bin/bash
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cp $DIR/../config.py $DIR/../vlmeval/
cp $DIR/../misc/* $DIR/../vlmeval/vlm/misc/
9 changes: 5 additions & 4 deletions vlmeval/chat_api/base.py
@@ -1,7 +1,7 @@
import time
import random as rd
from abc import abstractmethod
-import warnings
+from ..smp import get_logger

class BaseAPI:

@@ -18,13 +18,14 @@ def __init__(self,
        self.kwargs = kwargs
        self.verbose = verbose
        self.fail_msg = fail_msg
+        self.logger = get_logger('ChatAPI')
        if len(kwargs):
-            warnings.warn(f'BaseAPI received the following kwargs: {kwargs}')
-            warnings.warn(f'Will try to use them as kwargs for `generate`. ')
+            self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
+            self.logger.info(f'Will try to use them as kwargs for `generate`. ')

    @abstractmethod
    def generate_inner(self, inputs, **kwargs):
-        warnings.warn(f'For APIBase, generate_inner is an abstract method. ')
+        self.logger.warning(f'For APIBase, generate_inner is an abstract method. ')
        assert 0, 'generate_inner not defined'
        ret_code, answer, log = None, None, None
        # if ret_code is 0, means succeed
10 changes: 5 additions & 5 deletions vlmeval/chat_api/gpt.py
@@ -53,8 +53,8 @@ def __init__(self,

        openai_key = os.environ.get('OPENAI_API_KEY', None) if openai_key is None else openai_key
        self.openai_key = openai_key

-        assert isinstance(openai_key, str) and openai_key.startswith('sk-')
+        assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. '

        if api_base in APIBASES:
            self.client = OpenAI(api_key=openai_key, base_url=APIBASES[api_base])
@@ -86,7 +86,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
        context_window = GPT_context_window(self.model)
        max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
        if 0 < max_tokens <= 100:
-            warnings.warn('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
+            self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
        if max_tokens <= 0:
            return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '

@@ -104,9 +104,9 @@ def generate_inner(self, inputs, **kwargs) -> str:
            return 0, result, 'API Call Succeed'
        except:
            if self.verbose:
-                warnings.warn(f'OPENAI KEY {self.openai_key} FAILED !!!')
+                self.logger.warning(f'OPENAI KEY {self.openai_key} FAILED !!!')
                try:
-                    warnings.warn(response)
+                    self.logger.warning(response)
                except:
                    pass
            return -1, self.fail_msg, 'API Call Failed'
2 changes: 2 additions & 0 deletions vlmeval/chat_api/gpt_int.py
@@ -30,6 +30,8 @@ def __init__(self,
        if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']):
            keys = load(os.environ['KEYS'])
            headers['alles-apin-token'] = keys.get('alles-apin-token', '')
+        elif 'ALLES' in os.environ:
+            headers['alles-apin-token'] = os.environ['ALLES']
        self.headers = headers
        self.temperature = temperature
        self.timeout = timeout
14 changes: 7 additions & 7 deletions vlmeval/chat_api/hf_chat_model.py
@@ -1,8 +1,6 @@
import os
-import warnings
import os.path as osp
import torch
-from .base import BaseAPI
from ..smp import *

def get_gpu_num(model_name):
@@ -29,7 +27,7 @@ def get_gpu_num(model_name):
]
Auto_model = ['chatglm']

-class HFChatModel(BaseAPI):
+class HFChatModel:

    def _get_context_length(self, model, model_path):
        # By default, we use model.config.seq_length
@@ -50,7 +48,7 @@ def _get_context_length_robust(self, model, model_path):
            context_window = self._get_context_length(model, model_path)
            return context_window
        except:
-            warnings.warn(
+            self.logger.critical(
                "Failed to extract context_window information from config / generation_config. "
                "Please read the above code and check if the logic works for you model path"
            )
@@ -61,11 +59,13 @@ def __init__(self,
                 system_prompt: str=None,
                 **kwargs):

+        self.logger = get_logger('HFChatModel')
        if 'vicuna' in model_path.lower():
            try:
                from fastchat.model import get_conversation_template
            except:
-                warnings.warn("Please install fastchat first to use vicuna. ")
+                self.logger.critical("Please install fastchat first to use vicuna. ")
+                exit(-1)

        self.explicit_device = kwargs.pop('device', None)

@@ -81,7 +81,7 @@ def __init__(self,
        from transformers.generation import GenerationConfig

        if model_path not in validated_llms:
-            warnings.warn(f"{model_path} not in validated LLMs, may have inference troubles. ")
+            self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ")

        self.model_path = model_path
        if listinstr(Auto_model, model_path):
@@ -116,7 +116,7 @@ def __init__(self,
        self.answer_buffer = 192
        self.system_prompt = system_prompt
        for k, v in kwargs.items():
-            warnings.warn(f'Following args are passed and will be used as generation hyper-paras (If not set specifically), {k}: {v}. ')
+            self.logger.info(f'Following args are passed and will be used as generation hyper-paras (If not set specifically), {k}: {v}. ')
        self.kwargs = kwargs

    def generate_str(self, input, **kwargs):
4 changes: 3 additions & 1 deletion vlmeval/eval/__init__.py
@@ -1 +1,3 @@
-from .mme_eval import MME_rating, MME_postproc
+from .mme_eval import MME_rating, MME_postproc, MME_eval
+from .mmvet_eval import MMVet_eval
+from .multiple_choice import multiple_choice_eval
41 changes: 18 additions & 23 deletions vlmeval/eval/mme_eval.py
@@ -91,26 +91,25 @@ def MME_auxeval_tup(tup):
    model, line = tup
    return MME_auxeval(model, line)

-def MME_eval(args):
-    eval_file = args.data
-    rd.seed(2680)
-
-    suffix = eval_file.split('.')[-1]
-    data = load(args.data)
+def MME_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False):
+    logger = get_logger('Evaluation')
+
+    data = load(eval_file)
    if 'raw_prediction' not in data:
        data = MME_postproc(data)

    preds_map = {x: y for x, y in zip(data['index'], data['prediction'])}
    unknown = data[data['prediction'] == 'Unknown']
-    storage = args.data.replace('.xlsx', '_auxmatch.xlsx')
+    storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')

    if not osp.exists(storage):
-        assert args.model == 'gpt-3.5-turbo-0613'
+        assert model == 'chatgpt-0613'
+        model_name = 'gpt-3.5-turbo-0613'

        if INTERNAL:
-            model = OpenAIWrapperInternal(args.model, verbose=args.verbose)
+            model = OpenAIWrapperInternal(model_name, verbose=verbose)
        else:
-            model = OpenAIWrapper(args.model, verbose=args.verbose)
+            model = OpenAIWrapper(model_name, verbose=verbose)

        lt = len(unknown)
        lines = [unknown.iloc[i: i + 1] for i in range(lt)]
@@ -119,40 +118,36 @@ def MME_eval(args):

        if len(tups):
            # Do not save temporary file due to the fast speed
-            res = track_progress_rich(
-                MME_auxeval,
-                tups,
-                nproc=args.nproc,
-                chunksize=args.nproc)
+            res = track_progress_rich(MME_auxeval, tups, nproc=nproc, chunksize=nproc)

            for k, v in zip(indices, res):
                preds_map[k] = v

        data['prediction'] = [preds_map[idx] for idx in data['index']]
        dump(data, storage)
+    else:
+        logger.warning(f"GPT matching file {storage} already exists, will reuse it in MME_eval. ")

    data = load(storage)
    data["score"] = (data["answer"] == data["prediction"])
    dump(data, storage)
    score = MME_rating(storage)
    score_tgt = storage.replace('auxmatch.xlsx', 'score.csv')
    dump(score, score_tgt)

+    logger.info(f'MME_eval successfully finished evaluating {eval_file}, results saved in {score_tgt}')
+    logger.info('Score: ')
+    logger.info(score)
    return score

def parse_args():
    parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
    parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
-    parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default="gpt-3.5-turbo-0613", choices=['gpt-3.5-turbo-0613'])
-    parser.add_argument("--nproc", type=int, default=6)
+    parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default="chatgpt-0613", choices=['chatgpt-0613'])
+    parser.add_argument("--nproc", type=int, default=4)
    parser.add_argument("--verbose", action='store_true')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

-    suffix = args.data.split('.')[-1]
-    log_pth = args.data.replace('.' + suffix, f'_{args.model}_eval.log')
-    acc = MME_eval(args)
-    print(acc)
+    acc = MME_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)