Commit e64be33
add gsm evaluator(wip)
shuishen112 committed Dec 6, 2024
1 parent a198621 commit e64be33
Showing 2 changed files with 70 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mttl/datamodule/base.py
@@ -945,6 +945,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
        HellaswagMultiChoiceDataModule,
    )
    from mttl.datamodule.mathqa_data_module import MathQADataConfig, MathQADataModule
    from mttl.datamodule.gsm_data_module import GsmDataConfig, GsmDataModule
    from mttl.datamodule.base import DatasetConfig
    from mttl.datamodule.alpaca_data_module import (
        AlpacaCodeDataModule,
@@ -1073,6 +1074,12 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
            **common_kwargs,
        )
        dm = MathQADataModule(config, for_generation=for_generation)
    elif "gsm" in dataset:
        config = GsmDataConfig(
            **common_kwargs,
        )
        dm = GsmDataModule(config, for_generation=for_generation)

    elif "alpaca_code" in dataset:
        config = DatasetConfig(
            **common_kwargs,
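Note that the import added above references mttl/datamodule/gsm_data_module.py, which is not among the two files in this commit, so the new "gsm" branch will fail at import time until that module lands. As a rough, standalone sketch of the data that module needs to expose, here is one way to load GSM8K and produce (source, target) pairs compatible with the evaluator's exact-match metric; the dataset id, field names, and helper below are illustrative assumptions, not part of this commit or of mttl's datamodule API.

from dataclasses import dataclass
from typing import List

from datasets import load_dataset


@dataclass
class GsmExample:
    source: str  # the grade-school word problem
    target: str  # the final numeric answer, e.g. "72"


def load_gsm_split(split: str = "test") -> List[GsmExample]:
    """Load a GSM8K split, keeping only the final answer after '####'."""
    ds = load_dataset("openai/gsm8k", "main", split=split)
    examples = []
    for row in ds:
        # GSM8K answers end with "#### <number>"; keep just the number.
        answer = row["answer"].split("####")[-1].strip()
        examples.append(GsmExample(source=row["question"], target=answer))
    return examples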
63 changes: 63 additions & 0 deletions mttl/evaluators/gsm_evaluator.py
@@ -0,0 +1,63 @@
import os

from tqdm.auto import tqdm
from mttl.evaluators.base import GenerativeEvaluator, switch_to_eval_mode


class GsmEvaluator(GenerativeEvaluator):
    def __init__(
        self,
        datamodule,
        use_vllm=False,
        generation_kwargs=None,
        prepend_source=True,
        split="test",
    ):
        super().__init__(
            datamodule=datamodule,
            use_vllm=use_vllm,
            generation_kwargs=generation_kwargs,
        )

        self.split = split
        self.prepend_source = prepend_source
        os.environ["HF_ALLOW_CODE_EVAL"] = "1"

    @switch_to_eval_mode
    def evaluate(
        self,
        model,
        split=None,
        subsample=-1,
        num_batches=None,
        verbose=True,
        shuffle=False,
        output_path=None,
    ):
        dataloader = self.get_dataloader(split, subsample, shuffle=shuffle)

        pbar = tqdm(
            enumerate(dataloader),
            total=len(dataloader),
        )

        all_predictions = []
        all_targets = []

        for num_batch, batch in pbar:
            predictions = self.generate_for_batch(model, batch)

            all_predictions.extend(predictions)
            all_targets.extend(batch["target"])

        metrics = self.compute_metrics(all_predictions, all_targets)
        return metrics

    def compute_metrics(self, predictions, targets):
        # exact-match accuracy between each generated answer and its reference
        correct = 0
        for pred, target in zip(predictions, targets):
            if pred == target:
                correct += 1
        accuracy = correct / len(predictions)
        return accuracy
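Exact string equality will rarely fire for GSM8K, since the model typically emits a full reasoning chain while the reference (depending on how the still-missing data module formats targets) may be just the final number. Below is a hedged sketch of a normalization step that could be applied to both sides before comparison; the helper name and regex are illustrative, not part of this commit.

import re
from typing import Optional


def extract_final_number(text: str) -> Optional[str]:
    """Pull the final number out of a generation or a GSM8K-style '#### 42' target."""
    # Prefer the canonical GSM8K "#### <answer>" marker when it is present.
    marker = re.search(r"####\s*(-?\d[\d,]*(?:\.\d+)?)", text)
    if marker:
        return marker.group(1).replace(",", "")
    # Otherwise fall back to the last number appearing anywhere in the text.
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text)
    return numbers[-1].replace(",", "") if numbers else None


# Possible use inside compute_metrics:
#     if extract_final_number(pred) == extract_final_number(target):
#         correct += 1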
