[Feature] Add Longbenchv2 support (#1801)

* Create eval_longbenchv2.py
* Create longbenchv2_gen.py
* Update __init__.py
* Create longbenchv2.py
* Update datasets_info.py
* update (x6)

Co-authored-by: abrohamLee <146956824+abrohamLee@users.noreply.github.com>
Commit 117dc50 (parent f322043)
8 changed files with 375 additions and 0 deletions.
eval_longbenchv2.py (34 additions & 0 deletions)
@@ -0,0 +1,34 @@
from mmengine.config import read_base

with read_base():
    # Models
    from opencompass.configs.models.chatglm.lmdeploy_glm4_9b_chat import (
        models as lmdeploy_glm4_9b_chat_model,
    )
    from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import (
        models as lmdeploy_llama3_1_8b_instruct_model,
    )
    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import (
        models as lmdeploy_qwen2_5_7b_instruct_model,
    )

    # Datasets
    from opencompass.configs.datasets.longbenchv2.longbenchv2_gen import (
        LongBenchv2_datasets as LongBenchv2_datasets,
    )

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

models = sum([v for k, v in locals().items() if k.endswith('_model')], [])

for model in models:
    model['max_seq_len'] = 128 * 1024
    model['engine_config']['session_len'] = 128 * 1024
    model['engine_config']['tp'] = 2
    model['run_cfg']['num_gpus'] = 2
    # Drop middle tokens so the input stays shorter than session_len; 128k keeps
    # in sync with the original LongBench v2 code.
    # Dropping the middle is currently supported only for LMDeploy models.
    model['drop_middle'] = True


work_dir = './outputs/longbenchv2'
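The `drop_middle` option asks the inference backend to truncate over-long inputs from the middle rather than the tail, so both the head of the prompt and the question at the end survive. A minimal, hypothetical sketch of that idea (this is not the LMDeploy implementation, just the shape of the operation):

def drop_middle_tokens(token_ids, max_len=128 * 1024):
    # Hypothetical illustration: keep the head and tail, discard the middle,
    # so the result never exceeds max_len tokens.
    if len(token_ids) <= max_len:
        return token_ids
    half = max_len // 2
    return token_ids[:half] + token_ids[len(token_ids) - (max_len - half):]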
opencompass/configs/datasets/longbenchv2/longbenchv2_gen.py (4 additions & 0 deletions)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .longbenchv2_gen_75fbba import LongBenchv2_datasets
opencompass/configs/datasets/longbenchv2/longbenchv2_gen_75fbba.py (43 additions & 0 deletions)
@@ -0,0 +1,43 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import LongBenchv2Dataset, LongBenchv2Evaluator
from opencompass.utils.text_postprocessors import first_option_postprocess

LongBenchv2_reader_cfg = dict(
    input_columns=['context', 'question', 'choice_A', 'choice_B', 'choice_C', 'choice_D', 'difficulty', 'length'],
    output_column='answer',
)

LongBenchv2_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='Please read the following text and answer the questions below.\n <text> \n {context} \n </text> \n \n What is the correct answer to this question: {question} \n \n Choices: \n (A) {choice_A} \n (B) {choice_B} \n (C) {choice_C} \n (D) {choice_D} \n Let’s think step by step. Based on the above, what is the single, most likely answer choice? Format your response as follows: "The correct answer is (insert answer here)',
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

LongBenchv2_eval_cfg = dict(
    evaluator=dict(type=LongBenchv2Evaluator),
    pred_role='BOT',
    pred_postprocessor=dict(type=first_option_postprocess, options='ABCD')
)

LongBenchv2_datasets = [
    dict(
        type=LongBenchv2Dataset,
        abbr='LongBenchv2',
        path='opencompass/longbenchv2',
        reader_cfg=LongBenchv2_reader_cfg,
        infer_cfg=LongBenchv2_infer_cfg,
        eval_cfg=LongBenchv2_eval_cfg,
    )
]
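The prompt asks the model to close with "The correct answer is (X)", and `first_option_postprocess` reduces that free-form reply to a single letter before scoring. A minimal sketch of what the extraction amounts to, assuming a well-formed reply (an illustrative regex, not the OpenCompass implementation):

import re

def extract_first_option(text, options='ABCD'):
    # Illustrative only: return the first standalone upper-case option letter.
    match = re.search(rf'\(?([{options}])\)?', text)
    return match.group(1) if match else ''

extract_first_option('The correct answer is (B)')  # -> 'B'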
@@ -0,0 +1,135 @@
from opencompass.datasets import MusrDataset, MusrEvaluator
from opencompass.openicl import PromptTemplate, ZeroRetriever, GenInferencer


DATASET_CONFIGS = {
    'murder_mysteries': {
        'abbr': 'musr_murder_mysteries',
        'name': 'murder_mysteries',
        'path': 'opencompass/musr',
        'reader_cfg': dict(
            input_columns=['context', 'question_text', 'question', 'answer', 'choices', 'choices_str', 'intermediate_trees', 'intermediate_data', 'prompt', 'system_prompt', 'gold_answer', 'scidx', 'self_consistency_n', 'ablation_name'],
            output_column='gold_answer',
        ),
        'infer_cfg': dict(
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt='{system_prompt}'
                        )
                    ],
                    round=[
                        dict(
                            role='HUMAN',
                            prompt='{prompt}'
                        ),
                    ]
                ),
            ),
            retriever=dict(type=ZeroRetriever),
            inferencer=dict(type=GenInferencer),
        ),
        'eval_cfg': dict(
            evaluator=dict(
                type=MusrEvaluator,
                answer_index_modifier=1,
                self_consistency_n=1
            ),
        ),
    },
    'object_placements': {
        'abbr': 'musr_object_placements',
        'name': 'object_placements',
        'path': 'opencompass/musr',
        'reader_cfg': dict(
            input_columns=['context', 'question_text', 'question', 'answer', 'choices', 'choices_str', 'intermediate_trees', 'intermediate_data', 'prompt', 'system_prompt', 'gold_answer', 'scidx', 'self_consistency_n', 'ablation_name'],
            output_column='gold_answer',
        ),
        'infer_cfg': dict(
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt='{system_prompt}'
                        )
                    ],
                    round=[
                        dict(
                            role='HUMAN',
                            prompt='{prompt}'
                        ),
                    ]
                ),
            ),
            retriever=dict(type=ZeroRetriever),
            inferencer=dict(type=GenInferencer),
        ),
        'eval_cfg': dict(
            evaluator=dict(
                type=MusrEvaluator,
                answer_index_modifier=1,
                self_consistency_n=1
            ),
        ),
    },
    'team_allocation': {
        'abbr': 'musr_team_allocation',
        'name': 'team_allocation',
        'path': 'opencompass/musr',
        'reader_cfg': dict(
            input_columns=['context', 'question_text', 'question', 'answer', 'choices', 'choices_str', 'intermediate_trees', 'intermediate_data', 'prompt', 'system_prompt', 'gold_answer', 'scidx', 'self_consistency_n', 'ablation_name'],
            output_column='gold_answer',
        ),
        'infer_cfg': dict(
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt='{system_prompt}'
                        )
                    ],
                    round=[
                        dict(
                            role='HUMAN',
                            prompt='{prompt}'
                        ),
                    ]
                ),
            ),
            retriever=dict(type=ZeroRetriever),
            inferencer=dict(type=GenInferencer),
        ),
        'eval_cfg': dict(
            evaluator=dict(
                type=MusrEvaluator,
                answer_index_modifier=1,
                self_consistency_n=1
            ),
        ),
    },
}


musr_datasets = []

for config in DATASET_CONFIGS.values():
    dataset = dict(
        abbr=config['abbr'],
        type=MusrDataset,
        path=config['path'],
        name=config['name'],
        reader_cfg=config['reader_cfg'],
        infer_cfg=config['infer_cfg'],
        eval_cfg=config['eval_cfg'],
    )
    musr_datasets.append(dataset)
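The loop above expands the three entries of `DATASET_CONFIGS` into `musr_datasets`. If only one MuSR task is wanted in a run config, the generated list can be filtered by its `abbr` (illustrative usage, not part of the commit):

murder_only = [d for d in musr_datasets if d['abbr'] == 'musr_murder_mysteries']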
opencompass/datasets/longbenchv2.py (123 additions & 0 deletions)
@@ -0,0 +1,123 @@
from datasets import Dataset, load_dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset


@LOAD_DATASET.register_module()
class LongBenchv2Dataset(BaseDataset):

    @staticmethod
    def load(path: str):
        path = get_data_path(path)
        dataset = load_dataset('json', data_files=path)

        split = 'train'
        raw_data = []
        for i in range(len(dataset[split])):
            question = dataset[split]['question'][i]
            context = dataset[split]['context'][i]
            answer = dataset[split]['answer'][i]
            choice_A = dataset[split]['choice_A'][i]
            choice_B = dataset[split]['choice_B'][i]
            choice_C = dataset[split]['choice_C'][i]
            choice_D = dataset[split]['choice_D'][i]
            difficulty = dataset[split]['difficulty'][i]
            length = dataset[split]['length'][i]
            raw_data.append({
                'question': question,
                'context': context,
                'answer': answer,
                'choice_A': choice_A,
                'choice_B': choice_B,
                'choice_C': choice_C,
                'choice_D': choice_D,
                'difficulty': difficulty,
                'length': length
            })
        # Expose the reshaped rows as a 'test' split for the OpenCompass pipeline.
        dataset['test'] = Dataset.from_list(raw_data)
        return dataset


@ICL_EVALUATORS.register_module()
class LongBenchv2Evaluator(BaseEvaluator):

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        if not test_set:
            raise ValueError('test set is empty')

        # Accumulate overall accuracy plus per-difficulty and per-length buckets.
        metrics = {
            'total': {
                'correct': 0,
                'total': 0
            },
            'difficulty': {
                'easy': {
                    'correct': 0,
                    'total': 0
                },
                'hard': {
                    'correct': 0,
                    'total': 0
                }
            },
            'length': {
                'short': {
                    'correct': 0,
                    'total': 0
                },
                'medium': {
                    'correct': 0,
                    'total': 0
                },
                'long': {
                    'correct': 0,
                    'total': 0
                }
            }
        }

        for i, (pred, ref,
                sample) in enumerate(zip(predictions, references, test_set)):
            is_correct = (pred == ref)

            metrics['total']['total'] += 1
            if is_correct:
                metrics['total']['correct'] += 1

            difficulty = sample.get('difficulty', 'unknown')
            if difficulty in metrics['difficulty']:
                metrics['difficulty'][difficulty]['total'] += 1
                if is_correct:
                    metrics['difficulty'][difficulty]['correct'] += 1

            length = sample.get('length', 'unknown')
            if length in metrics['length']:
                metrics['length'][length]['total'] += 1
                if is_correct:
                    metrics['length'][length]['correct'] += 1

        results = {
            'accuracy':
            metrics['total']['correct'] / metrics['total']['total'] * 100
        }

        for diff in ['easy', 'hard']:
            if metrics['difficulty'][diff]['total'] > 0:
                acc = metrics['difficulty'][diff]['correct'] / metrics[
                    'difficulty'][diff]['total'] * 100
                results[f'accuracy_{diff}'] = acc

        for length in ['short', 'medium', 'long']:
            if metrics['length'][length]['total'] > 0:
                acc = metrics['length'][length]['correct'] / metrics['length'][
                    length]['total'] * 100
                results[f'accuracy_{length}'] = acc

        return results
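As a sanity check of the metric layout, a hypothetical toy invocation (field names follow the schema above; in the real pipeline `test_set` is the dataset's test split and predictions have already been reduced to single letters by `first_option_postprocess`):

evaluator = LongBenchv2Evaluator()
preds = ['A', 'C', 'B']
refs = ['A', 'B', 'B']
samples = [
    {'difficulty': 'easy', 'length': 'short'},
    {'difficulty': 'hard', 'length': 'medium'},
    {'difficulty': 'hard', 'length': 'long'},
]
print(evaluator.score(preds, refs, samples))
# roughly: {'accuracy': 66.7, 'accuracy_easy': 100.0, 'accuracy_hard': 50.0,
#           'accuracy_short': 100.0, 'accuracy_medium': 0.0, 'accuracy_long': 100.0}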