Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generative metric #169

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 27 additions & 26 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,33 @@ authors = [
classifiers = ["Private :: Do Not Upload"]
requires-python = ">=3.10,<4.0"
dependencies = [
"pyrallis>=0.3.1",
"wandb>=0.16.0",
"transformers>=4.37.0",
"torch>=2.1.1",
"datasets>=2.15.0",
"numpy>=1.26.2",
"accelerate>=0.30.0",
"scikit-learn>=1.3.2",
"evaluate>=0.4.1",
"rouge-score>=0.1.2",
"termcolor>=2.4.0",
"bitsandbytes==0.42.0",
"nbdime>=4.0.1",
"steering-vectors>=0.10.0",
"openai>=1.10.0",
"arrr>=1.0.4",
"spacy>=3.7.2",
"mosestokenizer>=1.2.1",
"gradio>=4.16.0",
"simple-parsing>=0.1.5",
"torchmetrics>=1.3.0.post0",
"umap-learn>=0.5.5",
"google-cloud-translate>=3.15.1",
"statsmodels>=0.14.1",
"jaxtyping>=0.2.28",
"concept-erasure>=0.2.4",
"pyrallis>=0.3.1",
"wandb>=0.16.0",
"transformers>=4.37.0",
"torch>=2.1.1",
"datasets>=2.15.0",
"numpy>=1.26.2",
"accelerate>=0.30.0",
"scikit-learn>=1.3.2",
"evaluate>=0.4.1",
"rouge-score>=0.1.2",
"termcolor>=2.4.0",
"bitsandbytes==0.42.0",
"nbdime>=4.0.1",
"steering-vectors>=0.10.0",
"openai>=1.10.0",
"arrr>=1.0.4",
"spacy>=3.7.2",
"mosestokenizer>=1.2.1",
"gradio>=4.16.0",
"simple-parsing>=0.1.5",
"torchmetrics>=1.3.0.post0",
"umap-learn>=0.5.5",
"google-cloud-translate>=3.15.1",
"statsmodels>=0.14.1",
"jaxtyping>=0.2.28",
"concept-erasure>=0.2.4",
"python-dotenv>=1.0.1",
]

[tool.black]
Expand Down
75 changes: 56 additions & 19 deletions repepo/core/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, ExitStack, contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from statistics import mean, stdev, StatisticsError
from tqdm import tqdm
from typing import Callable, Iterable, Sequence
from repepo.core.hook import SteeringHook
from repepo.core.pipeline import TextProbs
from repepo.core.types import Example
from repepo.core.types import Example, Completion
from repepo.core.pipeline import Pipeline

import numpy as np
Expand Down Expand Up @@ -115,6 +115,7 @@ class EvalPrediction:
negative_output_prob: TextProbs | None
# Example-level metrics
metrics: dict[str, float]
generations: list[Completion] = field(default_factory=list)


@dataclass
Expand Down Expand Up @@ -232,12 +233,59 @@ def score_prediction(self, prediction: EvalPrediction) -> float:
return positive_output_prob / (positive_output_prob + negative_output_prob)


def get_eval_hooks(
layer_id: int,
multiplier: float = 0,
completion_template: str | None = None,
system_prompt: str | None = None,
) -> list[EvalHook]:
eval_hooks = [
set_repe_direction_multiplier_at_eval(multiplier),
select_repe_layer_at_eval(layer_id),
]
if completion_template is not None:
eval_hooks.append(update_completion_template_at_eval(completion_template))
if system_prompt is not None:
eval_hooks.append(update_system_prompt_at_eval(system_prompt))
return eval_hooks


# Function to get a single prediction
def get_prediction(
pipeline: Pipeline,
example: Example,
evaluators: Sequence[Evaluator] = [LogitDifferenceEvaluator()],
slim_results: bool = False,
) -> EvalPrediction:
positive_probs = pipeline.calculate_output_logprobs(
example.positive, slim_results=slim_results
)
negative_probs = pipeline.calculate_output_logprobs(
example.negative, slim_results=slim_results
)
pred = EvalPrediction(
positive_output_prob=positive_probs,
negative_output_prob=negative_probs,
metrics={},
)
example_metrics = {}
for evaluator in evaluators:
example_metrics[evaluator.get_metric_name()] = evaluator.score_prediction(pred)
# for _ in range(n_generation):
# prompt = pipeline.build_generation_prompt(example.positive)
# response = pipeline.generate(example.positive)
# pred.generations.append(Completion(prompt=prompt, response=response))
pred.metrics = example_metrics
return pred


def evaluate(
pipeline: Pipeline,
dataset: Iterable[Example],
evaluators: Sequence[Evaluator],
# these eval_hooks allow us to do custom stuff to the pipeline only during evaluation,
# e.g. mess with the formatter to use CAA's special answer format
n_generation: int = 0,
eval_hooks: Sequence[EvalHook] = [],
show_progress: bool = True,
tqdm_desc: str = "Evaluating",
Expand All @@ -258,25 +306,14 @@ def evaluate(
logger.debug(
f"Example full prompt: \n {pipeline.build_full_prompt(example.positive)}"
)
positive_probs = pipeline.calculate_output_logprobs(
example.positive, slim_results=slim_results
)
negative_probs = pipeline.calculate_output_logprobs(
example.negative, slim_results=slim_results
)

pred = EvalPrediction(
positive_output_prob=positive_probs,
negative_output_prob=negative_probs,
metrics={},
pred = get_prediction(
pipeline,
example,
evaluators=evaluators,
n_generation=n_generation,
slim_results=slim_results,
)
example_metrics = {}
for evaluator in evaluators:
example_metrics[evaluator.get_metric_name()] = (
evaluator.score_prediction(pred)
)
pred.metrics = example_metrics

predictions.append(pred)

dataset_metrics: dict[str, float] = {}
Expand Down
27 changes: 26 additions & 1 deletion repepo/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __repr__(self) -> str:

@dataclass
class PipelineContext:
method: Literal["generate", "logprobs"]
method: Literal["generate", "logprobs", "forward"]
base_prompt: str
full_prompt: str
inputs: Any
Expand Down Expand Up @@ -118,6 +118,7 @@ def build_full_prompt(self, completion: Completion) -> str:
self.formatter.format_conversation(completion, self.conversation_history)
)

@torch.no_grad()
def generate(
self,
completion: Completion,
Expand Down Expand Up @@ -147,12 +148,36 @@ def generate(
return outputs_str.replace(base_prompt, "")
return outputs_str
raise RuntimeError("Should never get here")

@torch.no_grad()
def __call__(self, text: str):
"""A lightweight wrapper around model"""
# with ExitStack() as stack:
# for hook in self.hooks:
# stack.enter_context(hook(context))
# base_prompt = self.build_generation_prompt(completion)
# full_prompt = self.build_full_prompt(completion)
inputs: Any = self.tokenizer(text, return_tensors="pt")
inputs = inputs.to(self.model.device)
context = PipelineContext(
method="forward",
base_prompt="dummy",
full_prompt="dummy",
inputs=inputs,
pipeline=self,
)
with ExitStack() as stack:
for hook in self.hooks:
stack.enter_context(hook(context))
outputs = self.model(**inputs, output_hidden_states=False, return_dict=True)
return outputs

@torch.no_grad()
def calculate_output_logprobs(
self, completion: Completion, slim_results: bool = False
) -> TextProbs:
"""Calculate the logprobs for each token in the prompt + output"""

base_prompt = self.build_generation_prompt(completion)
full_prompt = self.build_full_prompt(completion)
inputs: Any = self.tokenizer(full_prompt, return_tensors="pt")
Expand Down
41 changes: 41 additions & 0 deletions repepo/data/multiple_choice/make_mwe_persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,35 @@ def convert_mwe_personas_dataset_caa(
return mwe_dataset


def convert_mwe_personas_dataset_gen(
mwe: list[dict[str, str]],
# lang_or_style: LangOrStyleCode | None = None,
# ctx: TS | None = None,
) -> Dataset:
"""Convert a dataset in MWE format to our format"""
mwe_dataset: Dataset = []
for element in mwe:
prompt = element["question"]

# if ctx is not None:
# ctx_str = load_translation(ctx, lang_or_style)
# prompt = f"{ctx_str}\n\n{prompt}"

positive = Completion(
prompt=prompt,
response=element["answer_matching_behavior"],
)

negative = Completion(
prompt=prompt,
response=element["answer_not_matching_behavior"],
)

ex = Example(positive=positive, negative=negative, steering_token_index=-2)
mwe_dataset.append(ex)
return mwe_dataset


def make_mwe_personas_caa():
"""Make MWE dataset"""
for dataset_path in get_raw_dataset_dir().glob("mwe/persona/*.jsonl"):
Expand All @@ -110,6 +139,18 @@ def make_mwe_personas_caa():
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


def make_mwe_personas_gen():
"""Make MWE dataset"""
for dataset_path in get_raw_dataset_dir().glob("mwe/persona/*.jsonl"):
with open(dataset_path, "r") as jsonfile:
list_dataset = [json.loads(line) for line in jsonfile]

persona = dataset_path.stem
filename = build_dataset_filename(persona + "_gen")
mwe_dataset: Dataset = convert_mwe_personas_dataset_gen(list_dataset)
jdump(mwe_dataset, get_dataset_dir() / "persona" / filename)


def make_mwe_personas_caa_ctx():
"""Make MWE dataset with generation contexts prepended"""
for dataset_path in get_raw_dataset_dir().glob("mwe/persona/*.jsonl"):
Expand Down
14 changes: 14 additions & 0 deletions repepo/notebooks/download_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

if __name__ == "__main__":
torch.set_grad_enabled(False)

model_name: str = "meta-llama/Llama-2-7b-chat-hf"
device: str = "cuda"

model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, device_map=0
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
Loading
Loading