Commit
Changed MMLU Pro for Non-COT Version (#3108)
siyagoel authored Oct 31, 2024
1 parent 712ac23 commit b92b93f
Showing 5 changed files with 159 additions and 4 deletions.
22 changes: 22 additions & 0 deletions src/helm/benchmark/run_specs/lite_run_specs.py
@@ -134,6 +134,28 @@ def get_mmlu_spec(subject: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> RunSpec:
)


@run_spec_function("mmlu_pro")
def get_mmlu_pro_spec(subject: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.mmlu_pro.MMLUProScenario", args={"subject": subject}
)

adapter_spec = get_multiple_choice_adapter_spec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
instructions=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.",
input_noun="Question",
output_noun="Answer",
)

return RunSpec(
name=f"mmlu_pro:subject={subject}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["mmlu_pro"],
)


@run_spec_function("gsm")
def get_gsm_spec() -> RunSpec:
scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.gsm_scenario.GSM8KScenario", args={})
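For orientation, here is a minimal sketch (not part of the commit) of how the newly registered run spec factory can be exercised directly from Python. It assumes a HELM development install; the subject value "law" is just an example.

# Sketch only: builds the RunSpec via the factory added in this commit.
from helm.benchmark.run_specs.lite_run_specs import get_mmlu_pro_spec

spec = get_mmlu_pro_spec(subject="law")
print(spec.name)  # mmlu_pro:subject=law
print(spec.scenario_spec.class_name)  # helm.benchmark.scenarios.mmlu_pro.MMLUProScenario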
72 changes: 72 additions & 0 deletions src/helm/benchmark/scenarios/mmlu_pro.py
@@ -0,0 +1,72 @@
from typing import Dict, List
from datasets import load_dataset

from helm.common.hierarchical_logger import hlog
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output


class MMLUProScenario(Scenario):
"""
    The MMLU-Pro dataset is an advanced version of the Massive Multitask Language Understanding (MMLU)
    benchmark, created to push the boundaries of language models' reasoning and comprehension skills.
    Designed as a more challenging evaluation, it increases the number of answer options per question
    from four to ten, significantly reducing the likelihood of correct random guesses. This makes the
    dataset better at distinguishing the capabilities of models on complex tasks.

    MMLU-Pro emphasizes reasoning over simple factual recall by integrating diverse, intricate questions
    across 14 domains, including subjects like biology, economics, law, and psychology. In addition, it
    addresses limitations of the original MMLU by filtering out trivial questions, making it a more
    robust benchmark. Performance comparisons suggest that models benefit from reasoning-based
    approaches (such as Chain of Thought, or CoT) on MMLU-Pro, in contrast to the original MMLU,
    where CoT showed less benefit. This makes MMLU-Pro especially suitable for evaluating advanced
    models that rely on nuanced reasoning and comprehension skills.

    Dataset: https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro
    Paper: https://arxiv.org/abs/2406.01574
"""

name = "mmlu_pro"
description = "Enhanced Massive Multitask Language Understanding with increased options and reasoning"
tags = ["knowledge", "multiple_choice", "reasoning"]

def __init__(self, subject: str):
super().__init__()
self.subject: str = subject

    def process_csv(self, data, split: str) -> List[Instance]:
        """Convert rows of a Hugging Face dataset split into HELM instances."""
        instances: List[Instance] = []
        hlog(f"Processing data for {split} split")
for row in data:
question = row["question"]
answers = row["options"][:10] # Limit to 10 answers if necessary
correct_choice = row["answer"]
answers_dict = dict(zip(["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"], answers))
correct_answer = answers_dict[correct_choice]

def answer_to_reference(answer: str) -> Reference:
return Reference(Output(text=answer), tags=[CORRECT_TAG] if answer == correct_answer else [])

instance = Instance(
input=Input(text=question),
references=list(map(answer_to_reference, answers)),
split=split,
)
instances.append(instance)
return instances

def get_instances(self, output_path: str) -> List[Instance]:
# Load the MMLU-Pro dataset from Hugging Face
dataset = load_dataset("TIGER-Lab/MMLU-Pro")

# Process all the instances
instances: List[Instance] = []
splits: Dict[str, str] = {
"validation": TRAIN_SPLIT,
"test": TEST_SPLIT,
}
for hf_split, split in splits.items():
data = dataset[hf_split].filter(lambda x: x["category"] == self.subject)
print(f"Filtered instances in {hf_split}: {len(data)}")
instances.extend(self.process_csv(data, split))

return instances
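For reference, process_csv assumes each dataset row carries at least "question", "options", "answer", and "category" fields, per the TIGER-Lab/MMLU-Pro dataset card. Below is a minimal sketch of the answer-resolution logic, using an invented row:

# Illustrative only: the row values are made up; field names follow the dataset card.
row = {
    "question": "Which gas makes up most of Earth's atmosphere?",
    "options": ["Oxygen", "Nitrogen", "Carbon dioxide", "Argon"],
    "answer": "B",  # letter keying into the options, A through J
    "category": "chemistry",  # matched against the scenario's subject
}
letters = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
correct_answer = dict(zip(letters, row["options"][:10]))[row["answer"]]
assert correct_answer == "Nitrogen"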
61 changes: 61 additions & 0 deletions src/helm/benchmark/scenarios/test_mmlu_pro_scenario.py
@@ -0,0 +1,61 @@
import pytest
from tempfile import TemporaryDirectory

from helm.benchmark.scenarios.mmlu_pro import MMLUProScenario
from helm.benchmark.scenarios.scenario import CORRECT_TAG, Input, Output, Reference


@pytest.mark.scenarios
def test_mmlu_pro_scenario():
with TemporaryDirectory() as tmpdir:
        # Test for the "math" subject
scenario = MMLUProScenario(subject="math")
instances = scenario.get_instances(tmpdir)
# assert len(instances) == 116
assert instances[1].input == Input(
text="Let V be the set of all real polynomials p(x). Let transformations T, S be defined on V by T:p(x) -> xp(x) and S:p(x) -> p'(x) = d/dx p(x), and interpret (ST)(p(x)) as S(T(p(x))). Which of the following is true?" # noqa: E501
)

# Ensure it handles up to 10 answer options
assert instances[1].references == [
Reference(output=Output(text="ST + TS is the identity map of V onto itself."), tags=[]),
Reference(output=Output(text="TS = 0"), tags=[]),
Reference(output=Output(text="ST = 1"), tags=[]),
Reference(output=Output(text="ST - TS = 0"), tags=[]),
Reference(output=Output(text="ST = T"), tags=[]),
Reference(output=Output(text="ST = 0"), tags=[]),
Reference(output=Output(text="ST = TS"), tags=[]),
Reference(output=Output(text="ST - TS is the identity map of V onto itself."), tags=[CORRECT_TAG]),
Reference(output=Output(text="TS = T"), tags=[]),
Reference(output=Output(text="ST = S"), tags=[]),
]
assert instances[1].split == "train"

# Optional: check if the explanation is properly included (if provided in the dataset)
# assert hasattr(instances[0], "explanation")

        # Test for the "health" subject
scenario = MMLUProScenario(subject="health")
instances = scenario.get_instances(tmpdir)
# assert len(instances) == 154
assert instances[0].input == Input(
text="Which of the following is the body cavity that contains the pituitary gland?"
)

# Check references with more answer choices and correct tagging
assert instances[0].references == [
Reference(output=Output(text="Ventral"), tags=[]),
Reference(output=Output(text="Dorsal"), tags=[]),
Reference(output=Output(text="Buccal"), tags=[]),
Reference(output=Output(text="Thoracic"), tags=[]),
Reference(output=Output(text="Pericardial"), tags=[]),
Reference(output=Output(text="Abdominal"), tags=[]),
Reference(output=Output(text="Spinal"), tags=[]),
Reference(output=Output(text="Pelvic"), tags=[]),
Reference(output=Output(text="Pleural"), tags=[]),
Reference(output=Output(text="Cranial"), tags=[CORRECT_TAG]),
]
assert instances[0].split == "train"

# Again, check for the presence of an explanation
# assert hasattr(instances[0], "explanation")
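To reproduce these assertions outside the pytest harness, the scenario can be driven directly (a sketch; the first run downloads the dataset from Hugging Face):

from tempfile import TemporaryDirectory

from helm.benchmark.scenarios.mmlu_pro import MMLUProScenario

with TemporaryDirectory() as tmpdir:
    instances = MMLUProScenario(subject="math").get_instances(tmpdir)
    print(len(instances), instances[1].split)  # the validation split maps to "train"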
6 changes: 3 additions & 3 deletions src/helm/benchmark/scenarios/test_mmlu_scenario.py
@@ -11,14 +11,14 @@ def test_mmlu_scenario():
scenario = MMLUScenario(subject="abstract_algebra")
instances = scenario.get_instances(tmpdir)
assert len(instances) == 116
- assert instances[0].input == Input(text="Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.")
- assert instances[0].references == [
+ assert instances[1].input == Input(text="Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.")
+ assert instances[1].references == [
Reference(output=Output(text="0"), tags=[]),
Reference(output=Output(text="1"), tags=[CORRECT_TAG]),
Reference(output=Output(text="2"), tags=[]),
Reference(output=Output(text="3"), tags=[]),
]
- assert instances[0].split == "train"
+ assert instances[1].split == "train"

scenario = MMLUScenario(subject="anatomy")
instances = scenario.get_instances(tmpdir)
2 changes: 1 addition & 1 deletion src/helm/common/images_utils.py
@@ -25,7 +25,7 @@ def open_image(image_location: str) -> Image.Image:
"""
image: Image.Image
if is_url(image_location):
- image = Image.open(requests.get(image_location, stream=True).raw)
+ image = Image.open(requests.get(image_location, stream=True).raw)  # type: ignore
else:
image = Image.open(image_location)
return image.convert("RGB")
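The added suppression is presumably needed because the streamed response exposes .raw as an urllib3 response object, which newer type stubs do not accept where Image.open expects a path or binary stream; at runtime the object is file-like, so the call works.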
