feat: slim eval results and allowing multiple multipliers in cross-steering #159

Merged · 5 commits · Apr 24, 2024
2 changes: 1 addition & 1 deletion pdm.lock (generated file; diff not rendered)

23 changes: 19 additions & 4 deletions repepo/core/evaluate.py
@@ -111,8 +111,8 @@ def set_repe_layer_hook(pipeline: Pipeline):

@dataclass
class EvalPrediction:
positive_output_prob: TextProbs
negative_output_prob: TextProbs
positive_output_prob: TextProbs | None
negative_output_prob: TextProbs | None
# Example-level metrics
metrics: dict[str, float]

@@ -169,6 +169,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float:

# the output might be longer than expected depending on how many tokens we generate
# so just verify that the expected output is a prefix of the generated output
assert prediction.positive_output_prob is not None
assert prediction.negative_output_prob is not None
positive_output_prob = prediction.positive_output_prob.sum_logprobs
negative_output_prob = prediction.negative_output_prob.sum_logprobs
return 1.0 if positive_output_prob > negative_output_prob else 0.0
@@ -189,6 +191,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float:
"""Score a single prediction based on difference in sum of logits."""

# calculate difference in logits
assert prediction.positive_output_prob is not None
assert prediction.negative_output_prob is not None
positive_output_logit = prediction.positive_output_prob.sum_logits
negative_output_logit = prediction.negative_output_prob.sum_logits
return positive_output_logit - negative_output_logit
@@ -212,6 +216,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float:
"""

# calculate normalized logprobs
assert prediction.positive_output_prob is not None
assert prediction.negative_output_prob is not None
positive_output_logprob = prediction.positive_output_prob.sum_logprobs
negative_output_logprob = prediction.negative_output_prob.sum_logprobs
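
The hunk above shows only the start of the normalized-logprob scorer. As a rough illustration of what normalizing the two completion logprobs against each other looks like, here is a minimal standalone sketch; the two-way softmax is my assumption about the normalization, not necessarily the exact formula repepo uses:

```python
import math

def normalized_positive_prob(pos_logprob: float, neg_logprob: float) -> float:
    """Two-way softmax over the positive/negative completions,
    computed in log space for numerical stability."""
    m = max(pos_logprob, neg_logprob)
    pos = math.exp(pos_logprob - m)
    neg = math.exp(neg_logprob - m)
    return pos / (pos + neg)

# P(pos) = 0.6, P(neg) = 0.2 -> normalized P(pos) = 0.6 / 0.8 = 0.75
assert abs(normalized_positive_prob(math.log(0.6), math.log(0.2)) - 0.75) < 1e-9
```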

@@ -236,6 +242,7 @@ def evaluate(
show_progress: bool = True,
tqdm_desc: str = "Evaluating",
logger: logging.Logger | None = None,
slim_results: bool = False, # if True, only return metrics for each example, not full token-level stats
) -> EvalResult:
# evaluate
predictions: list[EvalPrediction] = []
@@ -251,8 +258,12 @@
logger.debug(
f"Example full prompt: \n {pipeline.build_full_prompt(example.positive)}"
)
positive_probs = pipeline.calculate_output_logprobs(example.positive)
negative_probs = pipeline.calculate_output_logprobs(example.negative)
positive_probs = pipeline.calculate_output_logprobs(
example.positive, slim_results=slim_results
)
negative_probs = pipeline.calculate_output_logprobs(
example.negative, slim_results=slim_results
)

pred = EvalPrediction(
positive_output_prob=positive_probs,
@@ -271,5 +282,9 @@ def evaluate(
dataset_metrics: dict[str, float] = {}
for evaluator in evaluators:
dataset_metrics.update(evaluator(predictions))
if slim_results:
for prediction in predictions:
prediction.positive_output_prob = None
prediction.negative_output_prob = None
return EvalResult(predictions, dataset_metrics)
raise RuntimeError("Should never get here")
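
To make the slim-results behavior concrete, here is a minimal self-contained sketch of the stripping step at the end of evaluate; the dataclasses are simplified stand-ins for the repepo types in the diff above:

```python
from dataclasses import dataclass

@dataclass
class TextProbs:
    text: str
    sum_logprobs: float

@dataclass
class EvalPrediction:
    positive_output_prob: TextProbs | None
    negative_output_prob: TextProbs | None
    metrics: dict[str, float]

def strip_token_stats(predictions: list[EvalPrediction]) -> None:
    # Run after all evaluators have consumed the token-level stats:
    # the per-example metrics stay, the bulky TextProbs are dropped.
    for pred in predictions:
        pred.positive_output_prob = None
        pred.negative_output_prob = None

preds = [
    EvalPrediction(
        positive_output_prob=TextProbs("Yes", -1.2),
        negative_output_prob=TextProbs("No", -2.3),
        metrics={"mean_pos_prob": 0.75},
    )
]
strip_token_stats(preds)
assert preds[0].positive_output_prob is None
assert preds[0].metrics["mean_pos_prob"] == 0.75  # metrics survive
```
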
84 changes: 47 additions & 37 deletions repepo/core/pipeline.py
@@ -14,24 +14,28 @@ class TokenProb:
# Note: the logit, logprob are for this token, not the next token
logprob: float
logit: float
text: str
text: str | None = None
# Metrics for logits of other tokens that were in this token position
logit_mean: float = float("nan")
logit_std: float = float("nan")
logit_skew: float = float("nan")
logit_kurtosis: float = float("nan")
logit_100_quantile: float = float("nan")
logit_75_quantile: float = float("nan")
logit_50_quantile: float = float("nan")
logit_25_quantile: float = float("nan")
logit_0_quantile: float = float("nan")
logit_mean: float | None = None
logit_std: float | None = None
logit_skew: float | None = None
logit_kurtosis: float | None = None
logit_100_quantile: float | None = None
logit_75_quantile: float | None = None
logit_50_quantile: float | None = None
logit_25_quantile: float | None = None
logit_0_quantile: float | None = None

@property
def logit_max(self) -> float:
if self.logit_100_quantile is None:
raise ValueError("logit_100_quantile is not set")
return self.logit_100_quantile

@property
def logit_min(self) -> float:
if self.logit_0_quantile is None:
raise ValueError("logit_0_quantile is not set")
return self.logit_0_quantile
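
With the per-position statistics now optional, logit_max and logit_min fail loudly instead of silently returning NaN when the quantiles were skipped. A minimal sketch of this guard-property pattern, using a cut-down stand-in for TokenProb:

```python
from dataclasses import dataclass

@dataclass
class SlimTokenProb:
    token_id: int
    logprob: float
    logit: float
    logit_100_quantile: float | None = None  # skipped under slim_results

    @property
    def logit_max(self) -> float:
        if self.logit_100_quantile is None:
            raise ValueError("logit_100_quantile is not set")
        return self.logit_100_quantile

tp = SlimTokenProb(token_id=42, logprob=-0.7, logit=11.3)
try:
    tp.logit_max
except ValueError:
    print("quantiles were not computed for this run")
```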


@@ -115,7 +119,9 @@ def build_full_prompt(self, completion: Completion) -> str:
)

@torch.no_grad()
def calculate_output_logprobs(self, completion: Completion) -> TextProbs:
def calculate_output_logprobs(
self, completion: Completion, slim_results: bool = False
) -> TextProbs:
"""Calculate the logprobs for each token in the prompt + output"""
base_prompt = self.build_generation_prompt(completion)
full_prompt = self.build_full_prompt(completion)
@@ -154,36 +160,40 @@ def calculate_output_logprobs(self, completion: Completion) -> TextProbs:
# logits is of shape (1, seq_len, vocab_size)
assert logits.shape[0] == 1
logits = logits[0]
logit_moments = compute_moments(logits, dim=-1).cpu()
logit_quantiles = compute_quantiles(logits, dim=-1).cpu()
text_probs: list[TokenProb] = []

for token, logprob, logit, logit_moment, logit_quantile in zip(
target_ids[0].cpu(),
gen_logprobs,
gen_logits,
logit_moments,
logit_quantiles,
logit_moments = None
logit_quantiles = None
if not slim_results:
logit_moments = compute_moments(logits, dim=-1).cpu()
logit_quantiles = compute_quantiles(logits, dim=-1).cpu()

for i, (token, logprob, logit) in enumerate(
zip(
target_ids[0].cpu(),
gen_logprobs,
gen_logits,
)
):
if token not in self.tokenizer.all_special_ids:
text_probs.append(
TokenProb(
token_id=token.item(),
text=self.tokenizer.decode(token),
logprob=logprob.item(),
logit=logit.item(),
# moments
logit_mean=logit_moment[0].item(),
logit_std=logit_moment[1].item(),
logit_skew=logit_moment[2].item(),
logit_kurtosis=logit_moment[3].item(),
# quantiles
logit_0_quantile=logit_quantile[0].item(),
logit_25_quantile=logit_quantile[1].item(),
logit_50_quantile=logit_quantile[2].item(),
logit_75_quantile=logit_quantile[3].item(),
logit_100_quantile=logit_quantile[4].item(),
)
token_prob = TokenProb(
token_id=token.item(),
logprob=logprob.item(),
logit=logit.item(),
)
if not slim_results:
assert logit_moments is not None
assert logit_quantiles is not None
token_prob.text = self.tokenizer.decode(token)
token_prob.logit_mean = logit_moments[i, 0].item()
token_prob.logit_std = logit_moments[i, 1].item()
token_prob.logit_skew = logit_moments[i, 2].item()
token_prob.logit_kurtosis = logit_moments[i, 3].item()
token_prob.logit_0_quantile = logit_quantiles[i, 0].item()
token_prob.logit_25_quantile = logit_quantiles[i, 1].item()
token_prob.logit_50_quantile = logit_quantiles[i, 2].item()
token_prob.logit_75_quantile = logit_quantiles[i, 3].item()
token_prob.logit_100_quantile = logit_quantiles[i, 4].item()
text_probs.append(token_prob)
return TextProbs(text=full_prompt, token_probs=text_probs)
raise RuntimeError("Should never get here")
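
The indexing above implies compute_moments returns one row of (mean, std, skew, kurtosis) per sequence position, and compute_quantiles one row of (min, 25%, 50%, 75%, max). A plausible sketch of both helpers under those shape assumptions (the real repepo implementations may differ):

```python
import torch

def compute_moments(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return (..., 4): mean, std, skew, kurtosis along `dim`."""
    mean = x.mean(dim=dim)
    std = x.std(dim=dim)
    z = (x - mean.unsqueeze(dim)) / std.unsqueeze(dim)
    return torch.stack(
        [mean, std, z.pow(3).mean(dim=dim), z.pow(4).mean(dim=dim)], dim=-1
    )

def compute_quantiles(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return (..., 5): 0%, 25%, 50%, 75%, 100% quantiles along `dim`."""
    q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    # torch.quantile prepends the quantile axis; move it to the end
    return torch.quantile(x, q, dim=dim).movedim(0, -1)

logits = torch.randn(10, 32000)  # (seq_len, vocab_size)
assert compute_moments(logits).shape == (10, 4)
assert compute_quantiles(logits).shape == (10, 5)
```
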
10 changes: 9 additions & 1 deletion repepo/data/translate/translate_mwe.py
@@ -3,10 +3,18 @@
from typing import Any
from repepo.data.make_dataset import get_raw_dataset_dir
from repepo.data.utils import translate_row_recursive, collect_all_strings_recursive
from repepo.experiments.constants import MAIN_MWE_PERSONA_DATASETS

from repepo.translation.constants import LangOrStyleCode, LANG_OR_STYLE_MAPPING
from repepo.translation.lang_or_style_translate import lang_or_style_translate

MAIN_MWE_PERSONA_DATASETS = [
"believes-abortion-should-be-illegal",
"desire-for-recursive-self-improvement",
"willingness-to-be-non-HHH-to-be-deployed-in-the-real-world",
"machiavellianism",
"desire-to-persuade-people-to-be-less-harmful-to-others",
]


def translate(
dataset: list[Any],
61 changes: 16 additions & 45 deletions repepo/experiments/persona_generalization.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass, replace
from dataclasses import dataclass, field, replace
import json
import os
from pathlib import Path
@@ -20,11 +20,6 @@
CrossSteeringResult,
evaluate_cross_steering,
)
from repepo.steering.plot_cross_steering_result import (
DeltaType,
DistMetric,
plot_cross_steering_result,
)
from repepo.steering.plot_steering_vector_cos_similarities import (
plot_steering_vector_cos_similarities,
)
@@ -141,13 +136,12 @@ def run_persona_cross_steering_experiment(
train_split: str,
test_split: str,
layer: int,
multipliers: list[float],
normalize_steering_magnitude_to_baseline: bool = True,
show_progress: bool = True,
patch_generation_tokens_only: bool = True,
skip_first_n_generation_tokens: int = 0,
completion_template: str | None = None,
positive_multiplier: float = 1.0,
negative_multiplier: float = -1.0,
) -> PersonaCrossSteeringExperimentResult:
steering_vectors: dict[str, SteeringVector] = {}
test_ds = make_dataset(dataset_name, test_split)
@@ -237,8 +231,7 @@ def _eval_build_pipeline(
patch_generation_tokens_only=patch_generation_tokens_only,
skip_first_n_generation_tokens=skip_first_n_generation_tokens,
completion_template=completion_template,
positive_multiplier=positive_multiplier,
negative_multiplier=negative_multiplier,
multipliers=multipliers,
)
return PersonaCrossSteeringExperimentResult(
dataset_name=dataset_name,
@@ -255,18 +248,22 @@ def plot_steering_on_dataset(
):
cs = result.cross_steering_result
ds_index = cs.dataset_labels.index(dataset_version)
ds_neg_steering = cs.neg_steering[ds_index]
ds_pos_steering = cs.pos_steering[ds_index]

multipliers = [cs.neg_multiplier, 0.0, cs.pos_multiplier]
multipliers = [*list(cs.neg_steering.keys()), 0.0, *list(cs.pos_steering.keys())]

results_line_mean = []
for i, label in enumerate(cs.steering_labels):
results_line_mean.append(
[
ds_neg_steering[i].metrics[metric_name],
*[
res[ds_index][i].metrics[metric_name]
for res in cs.neg_steering.values()
],
cs.dataset_baselines[ds_index].metrics[metric_name],
ds_pos_steering[i].metrics[metric_name],
*[
res[ds_index][i].metrics[metric_name]
for res in cs.pos_steering.values()
],
]
)
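
The plotting change above relies on neg_steering and pos_steering now being dicts keyed by multiplier, with values indexed [dataset][steering_vector]. A minimal sketch of how one plotted row lines up with the multiplier axis under that assumed shape (plain dicts stand in for the repepo result objects):

```python
# Assumed shape: {multiplier: results[dataset_index][steering_index]}
neg_steering = {-1.0: [[{"mean_pos_prob": 0.21}]], -0.5: [[{"mean_pos_prob": 0.33}]]}
pos_steering = {0.5: [[{"mean_pos_prob": 0.58}]], 1.0: [[{"mean_pos_prob": 0.71}]]}
dataset_baselines = [{"mean_pos_prob": 0.45}]

ds_index, sv_index, metric = 0, 0, "mean_pos_prob"

# dicts preserve insertion order, so keys and values stay aligned
multipliers = [*neg_steering.keys(), 0.0, *pos_steering.keys()]
row = [
    *[res[ds_index][sv_index][metric] for res in neg_steering.values()],
    dataset_baselines[ds_index][metric],
    *[res[ds_index][sv_index][metric] for res in pos_steering.values()],
]
assert len(row) == len(multipliers)  # one y-value per x-axis multiplier
```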

@@ -307,32 +304,6 @@ def plot_steering_on_dataset(
plt.show()


def plot_persona_cross_steering_result(
result: PersonaCrossSteeringExperimentResult,
delta_type: DeltaType = "pos_base",
dist_metric: DistMetric = "raw",
metric_name: str = "mean_pos_prob",
save_path: str | None = None,
):
cross_steering_result = replace(
result.cross_steering_result,
steering_labels=[
shorten(label) for label in result.cross_steering_result.steering_labels
],
dataset_labels=[
shorten(label) for label in result.cross_steering_result.dataset_labels
],
)
return plot_cross_steering_result(
cross_steering_result,
title=result.dataset_name,
delta_type=delta_type,
dist_metric=dist_metric,
metric_name=metric_name,
save_path=save_path,
)


def extract_layer(result: PersonaCrossSteeringExperimentResult) -> int:
return list(list(result.steering_vectors.values())[0].layer_activations.keys())[0]

Expand Down Expand Up @@ -409,8 +380,9 @@ class PersonaGeneralizationExperimentConfig:
train_split: str = "0:50%"
test_split: str = "50:100%"
layer: int = 13
positive_multiplier: float = 1.0
negative_multiplier: float = -1.0
multipliers: list[float] = field(
default_factory=lambda: [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]
)


def run_persona_generalization_experiment(
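
Because a list is mutable, the new multipliers default has to go through field(default_factory=...); a bare list default would raise a ValueError when the dataclass is defined. A quick sketch of why, with a hypothetical cut-down config:

```python
from dataclasses import dataclass, field

@dataclass
class SteeringConfig:
    # `multipliers: list[float] = [-1.0, 1.0]` would raise ValueError here;
    # default_factory builds a fresh list for every instance instead.
    multipliers: list[float] = field(
        default_factory=lambda: [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]
    )

a, b = SteeringConfig(), SteeringConfig()
a.multipliers.append(2.0)
assert b.multipliers == [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]  # no shared state
```
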
@@ -463,8 +435,7 @@ def run_persona_generalization_experiment(
patch_generation_tokens_only=config.patch_generation_tokens_only,
skip_first_n_generation_tokens=config.skip_first_n_generation_tokens,
completion_template=config.completion_template,
positive_multiplier=config.positive_multiplier,
negative_multiplier=config.negative_multiplier,
multipliers=config.multipliers,
)
torch.save(result, results_save_file)
print("Done!")