adding Benchmark and Eval classes (#25)
* adding Benchmark and Eval classes
* simplifying benchmarking
* updating snapshot
Showing 9 changed files with 168 additions and 12 deletions.
repepo/core/benchmark.py (new file)
@@ -0,0 +1,48 @@
# pyright: strict, reportMissingTypeStubs=false

from dataclasses import dataclass
from typing import Optional, Sequence

from transformers import GenerationConfig

from repepo.algorithms.base import BaseAlgorithm
from repepo.core.evaluate import EvalPrediction, EvalResult, Evaluator
from repepo.core.format import AbstractFormatter, InputOutputFormatter
from repepo.core.pipeline import Pipeline

from repepo.core.types import Dataset, Model, Tokenizer


@dataclass
class Benchmark:
    name: str
    train_dataset: Dataset
    test_dataset: Dataset
    evaluators: list[Evaluator]


def train_and_evaluate_benchmark(
    model: Model,
    tokenizer: Tokenizer,
    algorithms: Sequence[BaseAlgorithm],
    benchmark: Benchmark,
    formatter: Optional[AbstractFormatter] = None,
    generation_config: Optional[GenerationConfig] = None,
) -> EvalResult:
    # set up pipeline
    pipeline = Pipeline(model, tokenizer, formatter=formatter or InputOutputFormatter())

    # train pipeline
    for algorithm in algorithms:
        pipeline = algorithm.run(pipeline, benchmark.train_dataset)

    # evaluate
    predictions: list[EvalPrediction] = []
    # TODO: support batching
    for example in benchmark.test_dataset:
        output = pipeline.generate(example, generation_config=generation_config)
        predictions.append(EvalPrediction(example, output))
    metrics: dict[str, float] = {}
    for evaluator in benchmark.evaluators:
        metrics.update(evaluator(predictions))
    return EvalResult(predictions, metrics)
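For reference, a minimal sketch of an algorithm compatible with the training loop above, assuming (as the call site implies) that a BaseAlgorithm subclass only needs a run(pipeline, dataset) method returning the possibly modified Pipeline. NoOpAlgorithm is a hypothetical name for illustration, not part of this commit:

from repepo.algorithms.base import BaseAlgorithm
from repepo.core.pipeline import Pipeline
from repepo.core.types import Dataset


class NoOpAlgorithm(BaseAlgorithm):
    """Hypothetical example: leaves the pipeline untouched.

    Assumes BaseAlgorithm requires only the run() hook used by
    train_and_evaluate_benchmark above.
    """

    def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
        # A real algorithm would train on `dataset` and/or wrap the pipeline.
        return pipeline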
repepo/core/evaluate.py (new file)
@@ -0,0 +1,36 @@
# pyright: strict


from dataclasses import dataclass
from statistics import mean
from typing import Callable, Sequence
from repepo.core.types import Example


@dataclass
class EvalPrediction:
    example: Example
    output: str


@dataclass
class EvalResult:
    predictions: list[EvalPrediction]
    metrics: dict[str, float]


Evaluator = Callable[[Sequence[EvalPrediction]], dict[str, float]]


class AccuracyEvaluator:
    def score_prediction(self, prediction: EvalPrediction) -> float:
        """Score a single prediction: 1 if correct, 0 otherwise."""
        expected = prediction.example.output
        # the output might be longer than the expected answer, depending on how
        # many tokens we generate, so just verify that the expected output is a
        # prefix of the generated output
        is_correct = prediction.output.strip().startswith(expected.strip())
        return 1.0 if is_correct else 0.0

    def __call__(self, predictions: Sequence[EvalPrediction]) -> dict[str, float]:
        pred_results = [self.score_prediction(pred) for pred in predictions]
        return {"accuracy": mean(pred_results)}
New test for the benchmark module:
@@ -0,0 +1,44 @@
import pytest
from transformers import GenerationConfig

from repepo.algorithms.icl import InContextLearning
from repepo.core.benchmark import Benchmark, train_and_evaluate_benchmark
from repepo.core.evaluate import AccuracyEvaluator
from repepo.core.types import Dataset, Example, Tokenizer, Model


def test_evaluate_benchmark(larger_model: Model, larger_tokenizer: Tokenizer) -> None:
    train_dataset: Dataset = [
        Example("", "Paris is located in the country of", "France"),
        Example("", "Shanghai is located in the country of", "China"),
        Example("", "Tokyo is located in the country of", "Japan"),
        Example("", "London is located in the country of", "England"),
    ]
    test_dataset = [
        Example("", "Kyoto is located in the country of", "Japan"),
        Example("", "Beijing is located in the country of", "China"),
        Example("", "FakePlace is located in the country of", "WrongAnswer"),
    ]
    benchmark = Benchmark(
        name="test benchmark",
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        evaluators=[AccuracyEvaluator()],
    )
    algorithms = [InContextLearning()]

    results = train_and_evaluate_benchmark(
        model=larger_model,
        tokenizer=larger_tokenizer,
        algorithms=algorithms,
        benchmark=benchmark,
        generation_config=GenerationConfig(
            max_length=100, pad_token_id=larger_tokenizer.eos_token_id
        ),
    )

    assert len(results.predictions) == 3
    assert results.predictions[0].example == test_dataset[0]
    assert results.predictions[1].example == test_dataset[1]
    assert results.predictions[2].example == test_dataset[2]
    assert results.metrics["accuracy"] == pytest.approx(2 / 3)
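The larger_model and larger_tokenizer arguments are pytest fixtures, presumably provided by a conftest.py elsewhere in the repository (their definitions are not part of this diff). The third example ("FakePlace" / "WrongAnswer") cannot be answered correctly, which is why the test asserts an accuracy of 2/3. With a standard pytest setup, this test can be selected by name with pytest -k test_evaluate_benchmark.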