diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py
index 877bee49..8557a860 100644
--- a/mttl/datamodule/base.py
+++ b/mttl/datamodule/base.py
@@ -945,6 +945,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
         HellaswagMultiChoiceDataModule,
     )
     from mttl.datamodule.mathqa_data_module import MathQADataConfig, MathQADataModule
+    from mttl.datamodule.gsm_data_module import GsmDataConfig, GsmDataModule
     from mttl.datamodule.base import DatasetConfig
     from mttl.datamodule.alpaca_data_module import (
         AlpacaCodeDataModule,
@@ -1073,6 +1074,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
             **common_kwargs,
         )
         dm = MathQADataModule(config, for_generation=for_generation)
+    elif "gsm" in dataset:
+        config = GsmDataConfig(
+            **common_kwargs,
+        )
+        dm = GsmDataModule(config, for_generation=for_generation)
     elif "alpaca_code" in dataset:
         config = DatasetConfig(
             **common_kwargs,
diff --git a/mttl/evaluators/gsm_evaluator.py b/mttl/evaluators/gsm_evaluator.py
new file mode 100644
index 00000000..6931415d
--- /dev/null
+++ b/mttl/evaluators/gsm_evaluator.py
@@ -0,0 +1,67 @@
+import os
+
+from tqdm.auto import tqdm
+from mttl.evaluators.base import GenerativeEvaluator, switch_to_eval_mode
+
+
+class GsmEvaluator(GenerativeEvaluator):
+    def __init__(
+        self,
+        datamodule,
+        use_vllm=False,
+        generation_kwargs=None,
+        prepend_source=True,
+        split="test",
+    ):
+        super().__init__(
+            datamodule=datamodule,
+            use_vllm=use_vllm,
+            generation_kwargs=generation_kwargs,
+        )
+
+        self.split = split
+        self.prepend_source = prepend_source
+        os.environ["HF_ALLOW_CODE_EVAL"] = "1"
+
+    @switch_to_eval_mode
+    def evaluate(
+        self,
+        model,
+        split=None,
+        subsample=-1,
+        num_batches=None,
+        verbose=True,
+        shuffle=False,
+        output_path=None,
+    ):
+        dataloader = self.get_dataloader(split, subsample, shuffle=shuffle)
+
+        pbar = tqdm(
+            enumerate(dataloader),
+            total=len(dataloader),
+        )
+
+        all_predictions = []
+        all_targets = []
+
+        for num_batch, batch in pbar:
+            # stop early when only a subset of batches is requested
+            if num_batches is not None and num_batch >= num_batches:
+                break
+
+            predictions = self.generate_for_batch(model, batch)
+
+            all_predictions.extend(predictions)
+            all_targets.extend(batch["target"])
+
+        metrics = self.compute_metrics(all_predictions, all_targets)
+        return metrics
+
+    def compute_metrics(self, predictions, targets):
+        # exact-match accuracy: a prediction counts as correct only if it equals the reference
+        correct = 0
+        for pred, target in zip(predictions, targets):
+            if pred == target:
+                correct += 1
+        accuracy = correct / len(predictions)
+        return accuracy
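
Minimal usage sketch for the new evaluator, assuming the usual mttl flow: args is a parsed config whose dataset name contains "gsm" (so get_datamodule selects GsmDataModule), model is an already-loaded model compatible with generate_for_batch, and the generation_kwargs values below are illustrative placeholders rather than required settings.

    from mttl.datamodule.base import get_datamodule
    from mttl.evaluators.gsm_evaluator import GsmEvaluator

    # build the GSM datamodule registered above; for_generation=True for free-form decoding
    datamodule = get_datamodule(args, for_generation=True)

    # exact-match accuracy on the test split; generation_kwargs keys are assumed HF generate options
    evaluator = GsmEvaluator(
        datamodule,
        split="test",
        generation_kwargs={"max_new_tokens": 256, "do_sample": False},
    )
    accuracy = evaluator.evaluate(model, split="test")
    print(f"GSM accuracy: {accuracy:.3f}")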