[Improvement] Support non-contiguous choices (open-compass#49)
kennymckormick authored Jan 12, 2024
1 parent eb6e31a commit 39c60fb
Showing 2 changed files with 10 additions and 30 deletions.
29 changes: 6 additions & 23 deletions vlmeval/evaluate/multiple_choice.py
@@ -57,15 +57,6 @@ def report_acc(df):
                 res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
     return pd.DataFrame(res)
 
-def extract_options(item):
-    options = []
-    for c in list(string.ascii_uppercase):
-        if c in item and not pd.isna(item[c]):
-            options.append(item[c])
-        else:
-            return options
-    return options
-
 def build_prompt(question, options, prediction):
     tmpl = (
         "You are an AI assistant who will help me to match an answer with several options of a single-choice question. "
@@ -110,15 +101,14 @@ def prefetch_answer(item):
 def extract_answer_from_item(model, item):
     logger = get_logger('Evaluation')
     # It will return: (pred, raw, llm_time)
-    options = extract_options(item)
-    option_str = build_options(options)
+    choices = build_choices(item)
+    option_str = build_option_str(choices)
 
     if cn_string(item['question']):
         prompt = build_prompt_cn(item['question'], option_str, item['prediction'])
     else:
         prompt = build_prompt(item['question'], option_str, item['prediction'])
     retry = 3
-    choices = build_choices(item)
 
     ret = can_infer(item['prediction'], choices)
     if ret:
@@ -127,25 +117,18 @@ def extract_answer_from_item(model, item):
     while retry:
         ans = model.generate(prompt)
         if 'Failed to obtain answer via API' in ans:
-            msg = 'GPT API failed to answer. '
-            logger.warning(msg)
-            retry -= 1
+            logger.warning('GPT API failed to answer. ')
         else:
             ret = can_infer(ans, choices)
             if ret:
                 return dict(opt=ret, log=ans)
             else:
                 logger.warning(f'Output includes 0 / > 1 letter among candidates {set(choices)} and Z: {ans}')
-                retry -= 1
+        retry -= 1
 
     if retry == 0:
-        num_options = sum([ch in item for ch in string.ascii_uppercase])
-        if num_options >= 2:
-            chars = string.ascii_uppercase[:num_options]
-            chars = chars + 'Z'
-            num_options += 1
-            tmp = rd.randint(0, num_options - 1)
-            return dict(opt=chars[tmp], log='Failed to predict, thus randomly generate one. ')
+        options = list(choices) + ['Z'] if 'Z' not in choices else []
+        return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ')
 
 def prefetch_sub_data(sub_data, answer_map, verbose=False):
     lt = len(sub_data)
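
Note (not part of the commit): a minimal sketch of why the random fallback changed, assuming a record that only defines options A and C. Here `item` and the dict comprehension standing in for build_choices are illustrative, not taken from the repository.

import random as rd
import string

item = {'A': 'cat', 'C': 'dog'}                                      # non-contiguous: no 'B'
choices = {k: v for k, v in item.items() if k in string.ascii_uppercase}  # stand-in for build_choices

# Old behaviour: count matching letters, then index a contiguous prefix of A-Z,
# so the fallback could draw 'B' even though 'B' is not a valid option here.
num_options = sum([ch in item for ch in string.ascii_uppercase])     # 2
old_pool = string.ascii_uppercase[:num_options] + 'Z'                # 'ABZ'

# New behaviour: draw only from the letters that actually exist, plus 'Z'.
new_pool = list(choices) + ['Z'] if 'Z' not in choices else []       # ['A', 'C', 'Z']

print(rd.choice(old_pool), rd.choice(new_pool))

With non-contiguous letters the old prefix-based pool can return a letter the question never offered; the new pool is restricted to the letters that are actually present.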
11 changes: 4 additions & 7 deletions vlmeval/smp.py
@@ -213,14 +213,11 @@ def cn_string(s):
 except ImportError:
     pass
 
-def build_options(option_list):
-    chars = string.ascii_uppercase
+def build_option_str(option_dict):
     s = 'There are several options: \n'
-    for c, opt in zip(chars, option_list):
-        if not pd.isna(opt):
-            s += f'{c}. {opt}\n'
-        else:
-            return s
+    for c, content in option_dict.items():
+        if not pd.isna(content):
+            s += f'{c}. {content}\n'
     return s
 
 def timestr(second=True, minute=False):
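
Note (not part of the commit): a minimal sketch contrasting the removed list-based helper with the new dict-based one. `old_extract_options` copies the deleted extract_options, `build_option_str` copies the new helper above, and `item` is a hypothetical row with a gap at B (in the repository it would be a pandas Series, but a plain dict behaves the same here).

import string
import pandas as pd

def old_extract_options(item):
    # Removed helper: walks A, B, C, ... and stops at the first missing letter.
    options = []
    for c in list(string.ascii_uppercase):
        if c in item and not pd.isna(item[c]):
            options.append(item[c])
        else:
            return options
    return options

def build_option_str(option_dict):
    # New helper: iterates whatever letters the record actually defines.
    s = 'There are several options: \n'
    for c, content in option_dict.items():
        if not pd.isna(content):
            s += f'{c}. {content}\n'
    return s

item = {'A': 'cat', 'C': 'dog'}      # option B is absent
print(old_extract_options(item))     # ['cat'] -- everything after the gap is lost
print(build_option_str(item))        # keeps both A and C in the judge prompt

Because the old helper returned at the first gap, option C never reached the matching prompt; iterating the dict keeps every option that is present.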
