Logprobs eval (#62)

* porting dataset handling code from tqa branch

* adding logprob calculation and adding an evaluator for multiple choice questions

chanind authored Jan 15, 2024
1 parent 2a2305e commit 8d07688

Showing 14 changed files with 408 additions and 97 deletions.
10 changes: 6 additions & 4 deletions repepo/baselines/sft/train.py
@@ -14,14 +14,13 @@

from dataclasses import dataclass
from dataclasses import field
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast

import transformers
from transformers import Trainer, AutoTokenizer
from repepo.core.types import Tokenizer

from repepo.data import get_dataset
from repepo.data import utils
from repepo.data import get_dataset, utils
from repepo.data.dataset import sft
from repepo.variables import Environ
from repepo.variables import Model
@@ -58,7 +57,10 @@ class TrainingArguments(transformers.TrainingArguments):
def make_supervised_data_module(tokenizer: Tokenizer, data_args: DataArguments) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
list_data_dict = get_dataset(data_args.dataset_name)
train_dataset = sft.SupervisedDataset(list_data_dict, tokenizer=tokenizer)
# TODO: this looks incorrect, this is probably fixed in the sft branch
train_dataset = sft.SupervisedDataset(
cast(Any, list_data_dict), tokenizer=tokenizer
)
data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
30 changes: 27 additions & 3 deletions repepo/core/benchmark.py
@@ -1,6 +1,6 @@
# pyright: strict, reportMissingTypeStubs=false

from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Optional, Sequence

from transformers import GenerationConfig
@@ -38,10 +38,34 @@ def train_and_evaluate_benchmark(

# evaluate
predictions: list[EvalPrediction] = []
requires_generation = any([e.requires_generation for e in benchmark.evaluators])
requires_probs = any([e.requires_probs for e in benchmark.evaluators])
# TODO: support batching
for example in benchmark.test_dataset:
output = pipeline.generate(example, generation_config=generation_config)
predictions.append(EvalPrediction(example, output))
generated_output = None
correct_output_probs = None
incorrect_outputs_probs = None
if requires_generation:
generated_output = pipeline.generate(
example, generation_config=generation_config
)
if requires_probs:
correct_output_probs = pipeline.calculate_output_logprobs(example)
if example.incorrect_outputs is not None:
incorrect_outputs_probs = [
pipeline.calculate_output_logprobs(
replace(example, output=incorrect_output)
)
for incorrect_output in example.incorrect_outputs
]
predictions.append(
EvalPrediction(
example=example,
generated_output=generated_output,
correct_output_probs=correct_output_probs,
incorrect_outputs_probs=incorrect_outputs_probs,
)
)
metrics: dict[str, float] = {}
for evaluator in benchmark.evaluators:
metrics.update(evaluator(predictions))
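
For illustration, a minimal sketch (not part of the commit) of how the evaluator flags gate the per-example work above; AccuracyEvaluator and MultipleChoiceAccuracyEvaluator are the evaluators added in repepo/core/evaluate.py below:

from repepo.core.evaluate import AccuracyEvaluator, MultipleChoiceAccuracyEvaluator

# Hypothetical combination: with both evaluators in a benchmark, the loop above
# both generates text and computes output logprobs for every test example.
evaluators = [AccuracyEvaluator(), MultipleChoiceAccuracyEvaluator()]
requires_generation = any(e.requires_generation for e in evaluators)  # True
requires_probs = any(e.requires_probs for e in evaluators)            # True
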
63 changes: 58 additions & 5 deletions repepo/core/evaluate.py
@@ -1,16 +1,20 @@
# pyright: strict


from abc import ABC, abstractmethod
from dataclasses import dataclass
from statistics import mean
from typing import Callable, Sequence
from typing import Optional, Sequence
from repepo.core.pipeline import TextProbs
from repepo.core.types import Example


@dataclass
class EvalPrediction:
example: Example
output: str
generated_output: Optional[str] = None
correct_output_probs: Optional[TextProbs] = None
incorrect_outputs_probs: Optional[list[TextProbs]] = None


@dataclass
@@ -19,18 +23,67 @@ class EvalResult:
metrics: dict[str, float]


Evaluator = Callable[[Sequence[EvalPrediction]], dict[str, float]]
class Evaluator(ABC):
requires_generation: bool = False
requires_probs: bool = False

@abstractmethod
def __call__(self, predictions: Sequence[EvalPrediction]) -> dict[str, float]:
raise NotImplementedError()


class AccuracyEvaluator(Evaluator):
"""
Evaluator that computes accuracy, i.e. the percentage of examples where the model
generated the correct output.
"""

requires_generation = True

class AccuracyEvaluator:
def score_prediction(self, prediction: EvalPrediction) -> float:
"""Score a single prediction, 1 if correct, 0 otherwise."""
expected = prediction.example.output
# the output might be longer than the expected depending on how many tokens we generate
# so just verify that the expected output is a prefix of the generated output
is_correct = prediction.output.strip().startswith(expected.strip())
assert prediction.generated_output is not None, "generation is required"
is_correct = prediction.generated_output.strip().startswith(expected.strip())
return 1.0 if is_correct else 0.0

def __call__(self, predictions: Sequence[EvalPrediction]) -> dict[str, float]:
pred_results = [self.score_prediction(pred) for pred in predictions]
return {"accuracy": mean(pred_results)}


class MultipleChoiceAccuracyEvaluator(Evaluator):
"""
Evaluator that scores multiple choice examples by computing the probability of
the correct output and comparing it to the probability of the incorrect outputs.
"""

requires_probs = True

def score_prediction(self, prediction: EvalPrediction) -> float:
"""Score a single prediction, 1 if correct, 0 otherwise."""
if (
prediction.example.incorrect_outputs is None
or len(prediction.example.incorrect_outputs) == 0
):
raise ValueError(
"Multiple choice evaluation requires examples to set incorrect_outputs"
)
# compare the summed logprob of the correct output against each incorrect output
assert prediction.correct_output_probs is not None, "output probs are required"
assert (
prediction.incorrect_outputs_probs is not None
), "output probs are required"
correct_prob = prediction.correct_output_probs.sum_logprobs
incorrect_probs = [
incorrect_output_probs.sum_logprobs
for incorrect_output_probs in prediction.incorrect_outputs_probs
]
return 1.0 if correct_prob > max(incorrect_probs) else 0.0

def __call__(self, predictions: Sequence[EvalPrediction]) -> dict[str, float]:
pred_results = [self.score_prediction(pred) for pred in predictions]
return {"accuracy": mean(pred_results)}
45 changes: 45 additions & 0 deletions repepo/core/pipeline.py
@@ -2,12 +2,33 @@
from typing import Any, Optional

from transformers import GenerationConfig
import torch

from .types import Example, Model, Tokenizer
from .prompt import Prompter, IdentityPrompter
from .format import Formatter, InputOutputFormatter


@dataclass
class TokenProb:
token_id: int
logprob: float
text: str


@dataclass
class TextProbs:
text: str
token_probs: list[TokenProb]

@property
def sum_logprobs(self) -> float:
return sum([tp.logprob for tp in self.token_probs])

def __repr__(self) -> str:
return f"TextProbs({self.text}:{self.sum_logprobs:.2f})"


@dataclass
class Pipeline:
"""Generation pipeline"""
@@ -41,3 +62,27 @@ def generate(
if remove_base_prompt:
return outputs_str.replace(base_prompt, "")
return outputs_str

def calculate_output_logprobs(self, example: Example) -> TextProbs:
"""Calculate the logprobs for each token in the prompt + output"""
base_prompt = self.build_generation_prompt(example)
full_prompt = base_prompt + example.output
inputs: Any = self.tokenizer(full_prompt, return_tensors="pt")
inputs = inputs.to(self.model.device)
outputs = self.model(**inputs, output_hidden_states=False, return_dict=True)
probs = torch.log_softmax(outputs.logits, dim=-1).detach().cpu()
# collect the probability of the generated token -- probability at index 0 corresponds to the token at index 1
probs = probs[:, :-1, :]
target_ids = inputs.input_ids[:, 1:].cpu()
gen_probs = torch.gather(probs, 2, target_ids[:, :, None]).squeeze(-1)[0]
text_probs: list[TokenProb] = []
for token, p in zip(target_ids[0], gen_probs):
if token not in self.tokenizer.all_special_ids:
text_probs.append(
TokenProb(
token_id=token.item(),
text=self.tokenizer.decode(token),
logprob=p.item(),
)
)
return TextProbs(text=full_prompt, token_probs=text_probs)
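
To make the indexing comment above explicit, a standalone toy version of the shift-and-gather (random logits, not tied to any model; purely illustrative):

import torch

# Logits at position i predict the token at position i + 1, so drop the last
# logit position and the first input id before gathering per-token logprobs.
torch.manual_seed(0)
vocab_size, seq_len = 10, 5
logits = torch.randn(1, seq_len, vocab_size)            # (batch, seq, vocab)
input_ids = torch.randint(0, vocab_size, (1, seq_len))  # (batch, seq)
logprobs = torch.log_softmax(logits, dim=-1)[:, :-1, :]
targets = input_ids[:, 1:]
per_token = torch.gather(logprobs, 2, targets[:, :, None]).squeeze(-1)[0]
print(per_token.shape)  # torch.Size([4]) -- one logprob per predicted token
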
92 changes: 15 additions & 77 deletions repepo/data/__init__.py
@@ -1,77 +1,15 @@
import functools
import pathlib
import random
from typing import List, Any, NewType

from repepo.variables import Environ
from repepo.core.types import Example
from dataclasses import dataclass

from .io import jdump
from .io import jload


def get_dataset_dir() -> pathlib.Path:
return pathlib.Path(Environ.DatasetDir)


def get_all_json_filepaths(root_dir: pathlib.Path) -> List[pathlib.Path]:
return list(root_dir.rglob("*.json"))


_DATASETS = {}
for path in get_all_json_filepaths(get_dataset_dir()):
_DATASETS[path.stem] = path.absolute()


@functools.lru_cache(1)
def list_datasets() -> tuple[str, ...]:
return tuple(_DATASETS.keys())


@dataclass
class DatasetSpec:
name: str
split: str = ":100%"
seed: int = 0 # unused atm


def parse_split(split_string, start_index, end_index) -> Any:
# Check if the split string starts with a colon and ends with a percentage sign
if split_string.startswith(":") and split_string.endswith("%"):
try:
# Remove the colon and percentage sign
split_value = float(split_string[1:-1])
# Calculate the start and end index of the split
split_start = start_index
split_end = start_index + int(
(end_index - start_index + 1) * (split_value / 100)
)
return slice(split_start, split_end)

except ValueError as e:
# TODO: Write a nice error message
raise ValueError(e)

else:
# Invalid format, return None
raise ValueError(f"Parse string {split_string} not recognized")


@functools.lru_cache(1)
def get_dataset(name: str):
if name not in _DATASETS:
raise ValueError(f"Unknown dataset: {name}")

example_dict_list = jload(_DATASETS[name])
dataset = []
for example_dict in example_dict_list:
dataset.append(Example(**example_dict))
return dataset


def make_dataset(spec: DatasetSpec):
dataset = get_dataset(spec.name)
random.Random(spec.seed).shuffle(dataset) # in-place shuffle
split = parse_split(spec.split, 0, len(dataset))
return dataset[split]
from .make_dataset import (
DatasetSpec,
list_datasets,
make_dataset,
get_dataset,
get_dataset_dir,
)

__all__ = [
"DatasetSpec",
"get_dataset",
"get_dataset_dir",
"list_datasets",
"make_dataset",
]
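
The dataset helpers now live in repepo.data.make_dataset but remain importable from repepo.data. A usage sketch based on the behavior of the code moved above (the dataset name is a placeholder for whichever JSON file exists in the dataset directory):

from repepo.data import DatasetSpec, list_datasets, make_dataset

print(list_datasets())  # stems of the *.json files found in the dataset dir
spec = DatasetSpec(name="stereoset", split=":50%")  # "stereoset" is a placeholder name
examples = make_dataset(spec)  # shuffled copy of the dataset, then roughly the first 50%
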
5 changes: 2 additions & 3 deletions repepo/data/generation/icl_task_vectors/make_all.py
@@ -1,9 +1,8 @@
import pathlib
from typing import Dict, List, Any

from repepo.data import get_dataset_dir
from repepo.data import jdump
from repepo.data import jload
from repepo.data.make_dataset import get_dataset_dir
from repepo.data.io import jdump, jload
from repepo.data.generation.icl_task_vectors.make_knowledge_data import (
prepare_knowledge_data,
)
2 changes: 1 addition & 1 deletion repepo/data/generation/make_stereoset.py
@@ -2,7 +2,7 @@
from datasets import load_dataset, DatasetDict

from repepo.data import get_dataset_dir
from repepo.data import jdump
from repepo.data.io import jdump

POS_INSTRUCTION = "Select the stereotypical example"
NEG_INSTRUCTION = "Select the anti-stereotypical example"
2 changes: 1 addition & 1 deletion repepo/data/io.py
@@ -53,7 +53,7 @@ def jdump(obj, f, mode="w", indent=4):
f.close()


def jload(f, mode="r"):
def jload(f, mode="r") -> Any:
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
(Diffs for the remaining changed files are not shown here.)