From 43b38906bdbf3b977fe9b1bcd749faa90bb67e1b Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Mon, 22 Apr 2024 21:14:24 +0100
Subject: [PATCH 1/3] adding slim eval results and allowing multiple
 multipliers in cross-steering

---
 pdm.lock                                      |   2 +-
 repepo/core/evaluate.py                       |  19 +++-
 repepo/core/pipeline.py                       |  87 +++++++-------
 repepo/experiments/persona_generalization.py  |  61 +++-------
 repepo/steering/evaluate_cross_steering.py    |  41 +++----
 repepo/steering/evaluate_steering_vector.py   |   2 +
 repepo/steering/plot_cross_steering_result.py | 106 ------------------
 .../test_persona_generalization.py            |  15 +--
 8 files changed, 106 insertions(+), 227 deletions(-)
 delete mode 100644 repepo/steering/plot_cross_steering_result.py

diff --git a/pdm.lock b/pdm.lock
index a364c49d..836f84b2 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -4,7 +4,7 @@
 [metadata]
 groups = ["default", "dev"]
 strategy = ["cross_platform"]
-lock_version = "4.4.1"
+lock_version = "4.4"
 content_hash = "sha256:28bfaa0122acd22674f70e4f2404b96d943ec1b911b129ed64dbf3f1c5f96648"
 
 [[package]]
diff --git a/repepo/core/evaluate.py b/repepo/core/evaluate.py
index 04486753..c2bdc916 100644
--- a/repepo/core/evaluate.py
+++ b/repepo/core/evaluate.py
@@ -111,8 +111,8 @@ def set_repe_layer_hook(pipeline: Pipeline):
 
 @dataclass
 class EvalPrediction:
-    positive_output_prob: TextProbs
-    negative_output_prob: TextProbs
+    positive_output_prob: TextProbs | None
+    negative_output_prob: TextProbs | None
 
     # Example-level metrics
     metrics: dict[str, float]
@@ -212,6 +212,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float:
         """
 
         # calculate normalized logprobs
+        assert prediction.positive_output_prob is not None
+        assert prediction.negative_output_prob is not None
         positive_output_logprob = prediction.positive_output_prob.sum_logprobs
         negative_output_logprob = prediction.negative_output_prob.sum_logprobs
 
@@ -236,6 +238,7 @@ def evaluate(
     show_progress: bool = True,
     tqdm_desc: str = "Evaluating",
     logger: logging.Logger | None = None,
+    slim_results: bool = False,  # if True, only return metrics for each example, not full token-level stats
 ) -> EvalResult:
     # evaluate
     predictions: list[EvalPrediction] = []
@@ -251,8 +254,12 @@ def evaluate(
                 logger.debug(
                     f"Example full prompt: \n {pipeline.build_full_prompt(example.positive)}"
                 )
-            positive_probs = pipeline.calculate_output_logprobs(example.positive)
-            negative_probs = pipeline.calculate_output_logprobs(example.negative)
+            positive_probs = pipeline.calculate_output_logprobs(
+                example.positive, slim_results=slim_results
+            )
+            negative_probs = pipeline.calculate_output_logprobs(
+                example.negative, slim_results=slim_results
+            )
 
             pred = EvalPrediction(
                 positive_output_prob=positive_probs,
@@ -271,5 +278,9 @@ def evaluate(
         dataset_metrics: dict[str, float] = {}
         for evaluator in evaluators:
             dataset_metrics.update(evaluator(predictions))
+        if slim_results:
+            for prediction in predictions:
+                prediction.positive_output_prob = None
+                prediction.negative_output_prob = None
         return EvalResult(predictions, dataset_metrics)
     raise RuntimeError("Should never get here")
diff --git a/repepo/core/pipeline.py b/repepo/core/pipeline.py
index a304d5df..a76714cc 100644
--- a/repepo/core/pipeline.py
+++ b/repepo/core/pipeline.py
@@ -14,24 +14,28 @@ class TokenProb:
     # Note: the logit, logprob are for this token, not the next token
     logprob: float
     logit: float
-    text: str
+    text: str | None = None
     # Metrics for logits of other tokens that were in this token position
-    logit_mean: float = float("nan")
-    logit_std: float = float("nan")
-    logit_skew: float = float("nan")
-    logit_kurtosis: float = float("nan")
-    logit_100_quantile: float = float("nan")
-    logit_75_quantile: float = float("nan")
-    logit_50_quantile: float = float("nan")
-    logit_25_quantile: float = float("nan")
-    logit_0_quantile: float = float("nan")
+    logit_mean: float | None = None
+    logit_std: float | None = None
+    logit_skew: float | None = None
+    logit_kurtosis: float | None = None
+    logit_100_quantile: float | None = None
+    logit_75_quantile: float | None = None
+    logit_50_quantile: float | None = None
+    logit_25_quantile: float | None = None
+    logit_0_quantile: float | None = None
 
     @property
     def logit_max(self) -> float:
+        if self.logit_100_quantile is None:
+            raise ValueError("logit_100_quantile is not set")
         return self.logit_100_quantile
 
     @property
     def logit_min(self) -> float:
+        if self.logit_0_quantile is None:
+            raise ValueError("logit_0_quantile is not set")
         return self.logit_0_quantile
@@ -62,8 +66,7 @@ class PipelineContext:
 
 
 class PipelineHook(Protocol):
-    def __call__(self, context: PipelineContext) -> AbstractContextManager[None]:
-        ...
+    def __call__(self, context: PipelineContext) -> AbstractContextManager[None]: ...
 
 
 def compute_moments(tensor: torch.Tensor, dim: int) -> torch.Tensor:
@@ -115,7 +118,9 @@ def build_full_prompt(self, completion: Completion) -> str:
         )
 
     @torch.no_grad()
-    def calculate_output_logprobs(self, completion: Completion) -> TextProbs:
+    def calculate_output_logprobs(
+        self, completion: Completion, slim_results: bool = False
+    ) -> TextProbs:
         """Calculate the logprobs for each token in the prompt + output"""
         base_prompt = self.build_generation_prompt(completion)
         full_prompt = self.build_full_prompt(completion)
@@ -154,36 +159,40 @@ def calculate_output_logprobs(self, completion: Completion) -> TextProbs:
         # logits is of shape (1, seq_len, vocab_size)
         assert logits.shape[0] == 1
         logits = logits[0]
-        logit_moments = compute_moments(logits, dim=-1).cpu()
-        logit_quantiles = compute_quantiles(logits, dim=-1).cpu()
         text_probs: list[TokenProb] = []
-        for token, logprob, logit, logit_moment, logit_quantile in zip(
-            target_ids[0].cpu(),
-            gen_logprobs,
-            gen_logits,
-            logit_moments,
-            logit_quantiles,
+
+        logit_moments = None
+        logit_quantiles = None
+        if not slim_results:
+            logit_moments = compute_moments(logits, dim=-1).cpu()
+            logit_quantiles = compute_quantiles(logits, dim=-1).cpu()
+
+        for i, (token, logprob, logit) in enumerate(
+            zip(
+                target_ids[0].cpu(),
+                gen_logprobs,
+                gen_logits,
+            )
         ):
             if token not in self.tokenizer.all_special_ids:
-                text_probs.append(
-                    TokenProb(
-                        token_id=token.item(),
-                        text=self.tokenizer.decode(token),
-                        logprob=logprob.item(),
-                        logit=logit.item(),
-                        # moments
-                        logit_mean=logit_moment[0].item(),
-                        logit_std=logit_moment[1].item(),
-                        logit_skew=logit_moment[2].item(),
-                        logit_kurtosis=logit_moment[3].item(),
-                        # quantiles
-                        logit_0_quantile=logit_quantile[0].item(),
-                        logit_25_quantile=logit_quantile[1].item(),
-                        logit_50_quantile=logit_quantile[2].item(),
-                        logit_75_quantile=logit_quantile[3].item(),
-                        logit_100_quantile=logit_quantile[4].item(),
-                    )
+                token_prob = TokenProb(
+                    token_id=token.item(),
+                    logprob=logprob.item(),
+                    logit=logit.item(),
                 )
+                if not slim_results:
+                    assert logit_moments is not None
+                    assert logit_quantiles is not None
+                    token_prob.text = self.tokenizer.decode(token)
+                    token_prob.logit_mean = logit_moments[i, 0].item()
+                    token_prob.logit_std = logit_moments[i, 1].item()
+                    token_prob.logit_skew = logit_moments[i, 2].item()
+                    token_prob.logit_kurtosis = logit_moments[i, 3].item()
+                    token_prob.logit_0_quantile = logit_quantiles[i, 0].item()
+                    token_prob.logit_25_quantile = logit_quantiles[i, 1].item()
+                    token_prob.logit_50_quantile = logit_quantiles[i, 2].item()
+                    token_prob.logit_75_quantile = logit_quantiles[i, 3].item()
+                    token_prob.logit_100_quantile = logit_quantiles[i, 4].item()
+                text_probs.append(token_prob)
         return TextProbs(text=full_prompt, token_probs=text_probs)
     raise RuntimeError("Should never get here")
diff --git a/repepo/experiments/persona_generalization.py b/repepo/experiments/persona_generalization.py
index 41ef30b6..bbbb13d1 100644
--- a/repepo/experiments/persona_generalization.py
+++ b/repepo/experiments/persona_generalization.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field, replace
 import json
 import os
 from pathlib import Path
@@ -20,11 +20,6 @@
     CrossSteeringResult,
     evaluate_cross_steering,
 )
-from repepo.steering.plot_cross_steering_result import (
-    DeltaType,
-    DistMetric,
-    plot_cross_steering_result,
-)
 from repepo.steering.plot_steering_vector_cos_similarities import (
     plot_steering_vector_cos_similarities,
 )
@@ -217,13 +212,12 @@ def run_persona_cross_steering_experiment(
     train_split: str,
     test_split: str,
     layer: int,
+    multipliers: list[float],
     normalize_steering_magnitude_to_baseline: bool = True,
     show_progress: bool = True,
     patch_generation_tokens_only: bool = True,
     skip_first_n_generation_tokens: int = 0,
     completion_template: str | None = None,
-    positive_multiplier: float = 1.0,
-    negative_multiplier: float = -1.0,
 ) -> PersonaCrossSteeringExperimentResult:
     steering_vectors: dict[str, SteeringVector] = {}
     test_ds = make_dataset(dataset_name, test_split)
@@ -313,8 +307,7 @@ def _eval_build_pipeline(
         patch_generation_tokens_only=patch_generation_tokens_only,
         skip_first_n_generation_tokens=skip_first_n_generation_tokens,
         completion_template=completion_template,
-        positive_multiplier=positive_multiplier,
-        negative_multiplier=negative_multiplier,
+        multipliers=multipliers,
     )
     return PersonaCrossSteeringExperimentResult(
         dataset_name=dataset_name,
@@ -331,18 +324,22 @@ def plot_steering_on_dataset(
 ):
     cs = result.cross_steering_result
     ds_index = cs.dataset_labels.index(dataset_version)
-    ds_neg_steering = cs.neg_steering[ds_index]
-    ds_pos_steering = cs.pos_steering[ds_index]
 
-    multipliers = [cs.neg_multiplier, 0.0, cs.pos_multiplier]
+    multipliers = [*list(cs.neg_steering.keys()), 0.0, *list(cs.pos_steering.keys())]
 
     results_line_mean = []
     for i, label in enumerate(cs.steering_labels):
         results_line_mean.append(
             [
-                ds_neg_steering[i].metrics[metric_name],
+                *[
+                    res[ds_index][i].metrics[metric_name]
+                    for res in cs.neg_steering.values()
+                ],
                 cs.dataset_baselines[ds_index].metrics[metric_name],
-                ds_pos_steering[i].metrics[metric_name],
+                *[
+                    res[ds_index][i].metrics[metric_name]
+                    for res in cs.pos_steering.values()
+                ],
             ]
         )
@@ -383,32 +380,6 @@ def plot_steering_on_dataset(
     plt.show()
 
 
-def plot_persona_cross_steering_result(
-    result: PersonaCrossSteeringExperimentResult,
-    delta_type: DeltaType = "pos_base",
-    dist_metric: DistMetric = "raw",
-    metric_name: str = "mean_pos_prob",
-    save_path: str | None = None,
-):
-    cross_steering_result = replace(
-        result.cross_steering_result,
-        steering_labels=[
-            shorten(label) for label in result.cross_steering_result.steering_labels
-        ],
-        dataset_labels=[
-            shorten(label) for label in result.cross_steering_result.dataset_labels
-        ],
-    )
-    return plot_cross_steering_result(
-        cross_steering_result,
-        title=result.dataset_name,
-        delta_type=delta_type,
-        dist_metric=dist_metric,
-        metric_name=metric_name,
-        save_path=save_path,
-    )
-
-
 def extract_layer(result: PersonaCrossSteeringExperimentResult) -> int:
     return list(list(result.steering_vectors.values())[0].layer_activations.keys())[0]
@@ -485,8 +456,9 @@ class PersonaGeneralizationExperimentConfig:
     train_split: str = "0:50%"
     test_split: str = "50:100%"
     layer: int = 15
-    positive_multiplier: float = 1.0
-    negative_multiplier: float = -1.0
+    multipliers: list[float] = field(
+        default_factory=lambda: [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]
+    )
 
 
 def run_persona_generalization_experiment(
@@ -538,8 +510,7 @@ def run_persona_generalization_experiment(
         patch_generation_tokens_only=config.patch_generation_tokens_only,
         skip_first_n_generation_tokens=config.skip_first_n_generation_tokens,
         completion_template=config.completion_template,
-        positive_multiplier=config.positive_multiplier,
-        negative_multiplier=config.negative_multiplier,
+        multipliers=config.multipliers,
     )
     torch.save(result, results_save_file)
     print("Done!")
diff --git a/repepo/steering/evaluate_cross_steering.py b/repepo/steering/evaluate_cross_steering.py
index 0a8284ca..0ecb6cac 100644
--- a/repepo/steering/evaluate_cross_steering.py
+++ b/repepo/steering/evaluate_cross_steering.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Callable, NamedTuple
 
@@ -25,10 +26,15 @@ class CrossSteeringResult:
     steering_labels: list[str]
     dataset_labels: list[str]
     dataset_baselines: list[EvalResult]
-    pos_steering: list[list[EvalResult]]
-    neg_steering: list[list[EvalResult]]
-    pos_multiplier: float
-    neg_multiplier: float
+    steering: dict[float, list[list[EvalResult]]]
+
+    @property
+    def neg_steering(self) -> dict[float, list[list[EvalResult]]]:
+        return {k: v for k, v in self.steering.items() if k < 0}
+
+    @property
+    def pos_steering(self) -> dict[float, list[list[EvalResult]]]:
+        return {k: v for k, v in self.steering.items() if k > 0}
 
 
 def evaluate_cross_steering(
@@ -37,9 +43,8 @@ def evaluate_cross_steering(
     layer: int,
     steering_vectors: dict[str, SteeringVector],
     datasets: dict[str, Dataset],
+    multipliers: list[float],
     build_pipeline: Callable[[Model, Tokenizer, str], Any] | None = None,
-    positive_multiplier: float = 1.0,
-    negative_multiplier: float = -1.0,
     patch_generation_tokens_only: bool = True,
     skip_first_n_generation_tokens: int = 0,
     completion_template: str | None = None,
@@ -58,8 +63,7 @@ def evaluate_cross_steering(
 
     # Get baseline logits
     baseline_results = []
-    pos_steering = []
-    neg_steering = []
+    steering: dict[float, list[list[EvalResult]]] = defaultdict(list)
    pbar = tqdm(
         total=len(dataset_labels) * len(steering_labels),
         desc="Evaluating cross-steering",
@@ -70,8 +74,7 @@ def evaluate_cross_steering(
     first_sv = list(steering_vectors.values())[0]
 
     for dataset_label in dataset_labels:
-        dataset_pos_steering = []
-        dataset_neg_steering = []
+        dataset_steering: dict[float, list[EvalResult]] = defaultdict(list)
         dataset = datasets[dataset_label]
         pipeline = build_pipeline(model, tokenizer, dataset_label)
         result = evaluate_steering_vector(
@@ -92,12 +95,12 @@ def evaluate_cross_steering(
         baseline_results.append(result)
         for steering_label in steering_labels:
             steering_vector = steering_vectors[steering_label]
-            neg_result, pos_result = evaluate_steering_vector(
+            results = evaluate_steering_vector(
                 pipeline,
                 steering_vector,
                 dataset,
                 layers=[layer],
-                multipliers=[negative_multiplier, positive_multiplier],
+                multipliers=[mul for mul in multipliers if mul != 0],
                 evaluators=[
                     NormalizedPositiveProbabilityEvaluator(),
                     LogitDifferenceEvaluator(),
@@ -106,19 +109,17 @@ def evaluate_cross_steering(
                 skip_first_n_generation_tokens=skip_first_n_generation_tokens,
                 completion_template=completion_template,
                 show_progress=False,
+                slim_results=True,
             )
-            dataset_neg_steering.append(neg_result)
-            dataset_pos_steering.append(pos_result)
+            for result, multiplier in zip(results, [m for m in multipliers if m != 0]):
+                dataset_steering[multiplier].append(result)
             pbar.update(1)
-        pos_steering.append(dataset_pos_steering)
-        neg_steering.append(dataset_neg_steering)
+        for multiplier, results in dataset_steering.items():
+            steering[multiplier].append(results)
 
     return CrossSteeringResult(
         steering_labels=steering_labels,
         dataset_labels=dataset_labels,
         dataset_baselines=baseline_results,
-        pos_steering=pos_steering,
-        neg_steering=neg_steering,
-        pos_multiplier=positive_multiplier,
-        neg_multiplier=negative_multiplier,
+        steering=steering,
     )
diff --git a/repepo/steering/evaluate_steering_vector.py b/repepo/steering/evaluate_steering_vector.py
index fd47bc3f..9771297e 100644
--- a/repepo/steering/evaluate_steering_vector.py
+++ b/repepo/steering/evaluate_steering_vector.py
@@ -35,6 +35,7 @@ def evaluate_steering_vector(
         NormalizedPositiveProbabilityEvaluator(),
     ],
     show_progress: bool = True,
+    slim_results: bool = False,
 ) -> list[EvalResult]:
     results = []
 
@@ -69,6 +70,7 @@ def evaluate_steering_vector(
             evaluators=evaluators,
             logger=logger,
             show_progress=show_progress,
+            slim_results=slim_results,
         )
         results.append(result)
         if logger is not None:
diff --git a/repepo/steering/plot_cross_steering_result.py b/repepo/steering/plot_cross_steering_result.py
deleted file mode 100644
index 009d051c..00000000
--- a/repepo/steering/plot_cross_steering_result.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from typing import Any, Literal
-import torch
-import seaborn as sns
-import numpy as np
-import matplotlib.pyplot as plt
-from repepo.steering.evaluate_cross_steering import (
-    CrossSteeringResult,
-)
-from matplotlib import patheffects
-
-from repepo.utils.stats import bernoulli_js_dist
-
-
-DeltaType = Literal["pos_base", "base_neg", "pos_neg"]
-DistMetric = Literal["js", "raw"]
-
-
-def _calc_deltas(
-    result: CrossSteeringResult,
-    delta_type: DeltaType,
-    dist_metric: DistMetric,
-    metric_name: str,
-) -> list[list[float]]:
-    dist_fn = bernoulli_js_dist if dist_metric == "js" else lambda x, y: y - x
-    deltas = []
-    for baseline, dataset_pos_steering, dataset_neg_steering in zip(
-        result.dataset_baselines, result.pos_steering, result.neg_steering
-    ):
-        if delta_type == "pos_base":
-            ds_deltas = [
-                dist_fn(baseline.metrics[metric_name], steering.metrics[metric_name])
-                for steering in dataset_pos_steering
-            ]
-        elif delta_type == "base_neg":
-            ds_deltas = [
-                dist_fn(steering.metrics[metric_name], baseline.metrics[metric_name])
-                for steering in dataset_neg_steering
-            ]
-        elif delta_type == "pos_neg":
-            ds_deltas = [
-                dist_fn(
-                    neg_steering.metrics[metric_name], pos_steering.metrics[metric_name]
-                )
-                for pos_steering, neg_steering in zip(
-                    dataset_pos_steering, dataset_neg_steering
-                )
-            ]
-        deltas.append(ds_deltas)
-    return deltas
-
-
-def plot_cross_steering_result(
-    result: CrossSteeringResult,
-    title: str,
-    delta_type: DeltaType = "pos_base",
-    dist_metric: DistMetric = "raw",
-    metric_name: str = "mean_pos_prob",
-    save_path: str | None = None,
-):
-    deltas = _calc_deltas(
-        result, delta_type, dist_metric=dist_metric, metric_name=metric_name
-    )
-
-    deltas_tensor = torch.tensor(deltas)
-    largest_abs_val = deltas_tensor.abs().max().item()
-    sns.heatmap(
-        deltas_tensor,
-        center=0,
-        cmap="RdBu_r",
-        vmin=-1 * largest_abs_val,
-        vmax=largest_abs_val,
-    )
-
-    # Iterate over the data and create a text annotation for each cell
-    for i in range(len(deltas)):
-        for j in range(len(deltas[i])):
-            # for some reason round() doesn't type check with float??
-            delta: Any = deltas[i][j]
-            plt.text(
-                j + 0.5,
-                i + 0.5,
-                round(delta, 3),
-                ha="center",
-                va="center",
-                color="w",
-                path_effects=[
-                    patheffects.withStroke(linewidth=2, foreground="#33333370")
-                ],
-            )
-
-    ds_labels = result.dataset_labels
-    sv_labels = result.steering_labels
-
-    # Add a colorbar to show the scale
-    # plt.colorbar()
-    plt.title(f"{title} ({dist_metric}, {delta_type})")
-    plt.xticks(ticks=np.arange(len(sv_labels)) + 0.5, labels=sv_labels)
-    plt.yticks(ticks=np.arange(len(ds_labels)) + 0.5, labels=ds_labels)
-    plt.xlabel("Steering vector")
-    plt.ylabel("Dataset")
-
-    if save_path is not None:
-        plt.savefig(save_path, dpi=300)
-
-    # Show the plot
-    plt.show()
diff --git a/tests/experiments/test_persona_generalization.py b/tests/experiments/test_persona_generalization.py
index 066cc0b9..7c52b5ef 100644
--- a/tests/experiments/test_persona_generalization.py
+++ b/tests/experiments/test_persona_generalization.py
@@ -131,10 +131,7 @@ def test_base_dataset_position_is_half_if_evenly_spaced(
             EvalResult(metrics={"mean_pos_prob": 0.25}, predictions=[]),
             EvalResult(metrics={"mean_pos_prob": 0.50}, predictions=[]),  # baseline
         ],
-        pos_steering=[],
-        neg_steering=[],
-        pos_multiplier=1.0,
-        neg_multiplier=-1.0,
+        steering={},
     ),
 )
 assert base_dataset_position(results, dist_metric=dist_metric) == 0.5
@@ -164,10 +161,7 @@ def test_base_dataset_position_is_near_one_if_base_is_near_pos(
             EvalResult(metrics={"mean_pos_prob": 0.25}, predictions=[]),
             EvalResult(metrics={"mean_pos_prob": 0.70}, predictions=[]),  # baseline
         ],
-        pos_steering=[],
-        neg_steering=[],
-        pos_multiplier=1.0,
-        neg_multiplier=-1.0,
+        steering={},
     ),
 )
 assert base_dataset_position(results, dist_metric=dist_metric) > 0.8
@@ -198,10 +192,7 @@ def test_base_dataset_position_is_near_zero_if_base_is_near_neg(
             EvalResult(metrics={"mean_pos_prob": 0.25}, predictions=[]),
             EvalResult(metrics={"mean_pos_prob": 0.30}, predictions=[]),  # baseline
         ],
-        pos_steering=[],
-        neg_steering=[],
-        pos_multiplier=1.0,
-        neg_multiplier=-1.0,
+        steering={},
     ),
 )
 assert base_dataset_position(results, dist_metric=dist_metric) < 0.2

From 2516c2d8bef37e56f966be4fa0a54ceef8aefa41 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Tue, 23 Apr 2024 17:35:24 +0100
Subject: [PATCH 2/3] fixing linting

---
 repepo/core/pipeline.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/repepo/core/pipeline.py b/repepo/core/pipeline.py
index a76714cc..0307a35c 100644
--- a/repepo/core/pipeline.py
+++ b/repepo/core/pipeline.py
@@ -66,7 +66,8 @@ class PipelineContext:
 
 class PipelineHook(Protocol):
-    def __call__(self, context: PipelineContext) -> AbstractContextManager[None]: ...
+    def __call__(self, context: PipelineContext) -> AbstractContextManager[None]:
+        ...
def compute_moments(tensor: torch.Tensor, dim: int) -> torch.Tensor: From deb5524a44d8671c82d71c481147ef1e9e874b5f Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 24 Apr 2024 15:16:12 +0100 Subject: [PATCH 3/3] fixing typing --- repepo/core/evaluate.py | 4 ++++ repepo/data/translate/translate_mwe.py | 10 +++++++++- repepo/steering/plots/utils.py | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/repepo/core/evaluate.py b/repepo/core/evaluate.py index c2bdc916..fc6165ff 100644 --- a/repepo/core/evaluate.py +++ b/repepo/core/evaluate.py @@ -169,6 +169,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float: # the output might be longer than the expected depending on how many tokens we generate # so just verify that the expected output is a prefix of the generated output + assert prediction.positive_output_prob is not None + assert prediction.negative_output_prob is not None positive_output_prob = prediction.positive_output_prob.sum_logprobs negative_output_prob = prediction.negative_output_prob.sum_logprobs return 1.0 if positive_output_prob > negative_output_prob else 0.0 @@ -189,6 +191,8 @@ def score_prediction(self, prediction: EvalPrediction) -> float: """Score a single prediction based on difference in sum of logits.""" # calculate difference in logits + assert prediction.positive_output_prob is not None + assert prediction.negative_output_prob is not None positive_output_logit = prediction.positive_output_prob.sum_logits negative_output_logit = prediction.negative_output_prob.sum_logits return positive_output_logit - negative_output_logit diff --git a/repepo/data/translate/translate_mwe.py b/repepo/data/translate/translate_mwe.py index 9e322eb7..080739b9 100644 --- a/repepo/data/translate/translate_mwe.py +++ b/repepo/data/translate/translate_mwe.py @@ -3,10 +3,18 @@ from typing import Any from repepo.data.make_dataset import get_raw_dataset_dir from repepo.data.utils import translate_row_recursive, collect_all_strings_recursive -from repepo.experiments.constants import MAIN_MWE_PERSONA_DATASETS + from repepo.translation.constants import LangOrStyleCode, LANG_OR_STYLE_MAPPING from repepo.translation.lang_or_style_translate import lang_or_style_translate +MAIN_MWE_PERSONA_DATASETS = [ + "believes-abortion-should-be-illegal", + "desire-for-recursive-self-improvement", + "willingness-to-be-non-HHH-to-be-deployed-in-the-real-world", + "machiavellianism", + "desire-to-persuade-people-to-be-less-harmful-to-others", +] + def translate( dataset: list[Any], diff --git a/repepo/steering/plots/utils.py b/repepo/steering/plots/utils.py index e85f83d8..d60e75c8 100644 --- a/repepo/steering/plots/utils.py +++ b/repepo/steering/plots/utils.py @@ -42,6 +42,8 @@ def make_results_df(results: list[tuple[SteeringConfig, EvalResult]]): row.update(result.metrics) # Sample-wise results for prediction in result.predictions: + assert prediction.positive_output_prob is not None + assert prediction.negative_output_prob is not None sample_row = row.copy() # Add the text output sample_row.update( @@ -53,6 +55,7 @@ def make_results_df(results: list[tuple[SteeringConfig, EvalResult]]): # Add raw metrics sample_row.update(prediction.metrics) # Add other information about logits + token_difference_position = find_token_difference_position( prediction.positive_output_prob, prediction.negative_output_prob )
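
Usage note (illustrative, not part of the patches): with slim_results=True,
calculate_output_logprobs skips the per-token moment/quantile stats, the
evaluators run on the remaining logprobs, and evaluate() then drops the
token-level TextProbs from each EvalPrediction so that only per-example
metrics survive serialization. The sketch below shows the resulting data
shape; it assumes EvalPrediction has only the three fields visible in this
diff, and the metric value is made up.

    from repepo.core.evaluate import EvalPrediction

    # What a prediction looks like after evaluate(..., slim_results=True):
    # dataset metrics were computed first, then the token-level TextProbs
    # were set to None to keep saved results small.
    pred = EvalPrediction(
        positive_output_prob=None,
        negative_output_prob=None,
        metrics={"mean_pos_prob": 0.7},  # hypothetical value
    )

    # Evaluators that need token-level stats assert these fields are not
    # None, so they can only run before the stripping step inside evaluate().
    assert pred.positive_output_prob is None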
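A similar sketch of the new multiplier-keyed CrossSteeringResult layout,
using the same constructor shapes as the updated tests; the labels and
metric values here are invented for illustration. steering is indexed as
steering[multiplier][dataset_index][steering_vector_index], and the
pos_steering / neg_steering properties partition it by multiplier sign,
replacing the old fixed pos_multiplier / neg_multiplier pair.

    from repepo.core.evaluate import EvalResult
    from repepo.steering.evaluate_cross_steering import CrossSteeringResult

    result = CrossSteeringResult(
        steering_labels=["sv_a"],
        dataset_labels=["ds_a"],
        dataset_baselines=[EvalResult(metrics={"mean_pos_prob": 0.50}, predictions=[])],
        steering={
            -1.0: [[EvalResult(metrics={"mean_pos_prob": 0.30}, predictions=[])]],
            1.0: [[EvalResult(metrics={"mean_pos_prob": 0.70}, predictions=[])]],
        },
    )

    # Iterate positive multipliers in ascending order; neg_steering works
    # the same way for the negative side.
    for multiplier, per_dataset in sorted(result.pos_steering.items()):
        for ds_label, per_vector in zip(result.dataset_labels, per_dataset):
            for sv_label, eval_result in zip(result.steering_labels, per_vector):
                print(multiplier, ds_label, sv_label, eval_result.metrics["mean_pos_prob"])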