adding a test to assert our steering is identical to CAA steering
chanind committed Jan 23, 2024
1 parent aa1dd24 commit f26b2df
Showing 3 changed files with 76 additions and 67 deletions.
8 changes: 7 additions & 1 deletion repepo/algorithms/repe.py
@@ -48,6 +48,7 @@ class SteeringHook:
     steering_vector: SteeringVector
     direction_multiplier: float
     patch_generation_tokens_only: bool
+    skip_first_n_generation_tokens: int
     layer_config: ModelLayerConfig | None
 
     # PipelineContext is created in both `pipeline.generate` or `pipeline.calculate_output_logprobs`,
@@ -58,11 +59,12 @@ def __call__(self, context: PipelineContext):
         try:
             min_token_index = 0
             if self.patch_generation_tokens_only:
-                min_token_index = _find_generation_start_token_index(
+                gen_start_index = _find_generation_start_token_index(
                     context.pipeline.tokenizer,
                     context.base_prompt,
                     context.full_prompt,
                 )
+                min_token_index = gen_start_index + self.skip_first_n_generation_tokens
             handle = self.steering_vector.patch_activations(
                 model=context.pipeline.model,
                 layer_config=self.layer_config,
@@ -83,6 +85,7 @@ class RepeReadingControl(Algorithm):
     direction_multiplier: float
     layer_config: ModelLayerConfig | None
     patch_generation_tokens_only: bool
+    skip_first_n_generation_tokens: int
     read_token_index: int
     seed: int
 
@@ -96,6 +99,7 @@ def __init__(
         layer_config: Optional[ModelLayerConfig] = None,
         direction_multiplier: float = 1.0,
         patch_generation_tokens_only: bool = True,
+        skip_first_n_generation_tokens: int = 0,  # only relevant if patch_generation_tokens_only is True
         skip_reading: bool = False,
         override_vector: Optional[SteeringVector] = None,
         skip_control: bool = False,
@@ -108,6 +112,7 @@ def __init__(
         self.layer_type = layer_type
         self.seed = seed
         self.patch_generation_tokens_only = patch_generation_tokens_only
+        self.skip_first_n_generation_tokens = skip_first_n_generation_tokens
         _validate_reading_template(reading_template)
         self.reading_template = reading_template
         self.layers = layers
@@ -186,6 +191,7 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
             steering_vector=steering_vector,
             direction_multiplier=self.direction_multiplier,
             patch_generation_tokens_only=self.patch_generation_tokens_only,
+            skip_first_n_generation_tokens=self.skip_first_n_generation_tokens,
             layer_config=self.layer_config,
         )
 
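In short, the new `skip_first_n_generation_tokens` option only shifts where steering starts: the hook still locates the first generation token, then moves the patch start forward by N. A standalone sketch of that index arithmetic (the helper name and token counts below are made up for illustration; this is not the repepo implementation):

def first_patched_token_index(
    gen_start_index: int,
    patch_generation_tokens_only: bool,
    skip_first_n_generation_tokens: int = 0,
) -> int:
    """Return the first token position the steering hook would patch."""
    if not patch_generation_tokens_only:
        return 0  # patch every token, prompt included
    # Skip the prompt, then optionally skip the first N generated tokens too.
    return gen_start_index + skip_first_n_generation_tokens


# With a 10-token prompt and skip_first_n_generation_tokens=1 (as in the new
# test below), patching starts one token later than the generation start.
assert first_patched_token_index(10, True, skip_first_n_generation_tokens=1) == 11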
2 changes: 1 addition & 1 deletion tests/_original_caa/llama_wrapper.py
@@ -69,7 +69,7 @@ def forward(self, *args, **kwargs):
                 after=self.after_position,
                 do_projection=self.do_projection,
             )
-            output = (augmented_output + self.add_activations,) + output[1:]
+            output = (augmented_output,) + output[1:]
 
         if not self.save_internal_decodings:
             return output
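The one-line fix above removes a double application of the steering vector: `augmented_output` already has `self.add_activations` added into it, so adding it again steered with twice the intended magnitude. A toy illustration of that bug with plain tensors (not the CAA wrapper itself):

import torch

# `augmented` already includes the steering vector, so adding
# `add_activations` a second time doubles its effect.
hidden = torch.zeros(4)
add_activations = torch.ones(4)

augmented = hidden + add_activations  # output that already had the vector added
buggy = augmented + add_activations   # old line: vector applied twice
fixed = augmented                     # new line: vector applied once

assert torch.equal(fixed, hidden + add_activations)
assert torch.equal(buggy, hidden + 2 * add_activations)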
133 changes: 68 additions & 65 deletions tests/algorithms/test_repe.py
@@ -1,11 +1,12 @@
 import torch
 from repepo.algorithms.repe import (
     RepeReadingControl,
+    SteeringHook,
     _find_generation_start_token_index,
 )
 from repepo.core.format import InputOutputFormatter, LlamaChatFormatter
 from repepo.core.types import Dataset, Example, Tokenizer
-from repepo.core.pipeline import Pipeline
+from repepo.core.pipeline import Pipeline, PipelineContext
 from repepo.core.prompt import LlamaChatPrompter
 from syrupy import SnapshotAssertion
 from transformers import GPTNeoXForCausalLM, LlamaForCausalLM
@@ -217,70 +218,72 @@ def test_RepeReadingControl_run(
     assert original_outputs != new_outputs
 
 
-# TODO: uncomment this test when https://github.com/nrimsky/SycophancySteering/issues/2 is fixed
-# def test_RepeReadingControl_run_steering_matches_caa_llama_wrapper(
-#     empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer
-# ) -> None:
-#     model = empty_llama_model
-#     tokenizer = llama_chat_tokenizer
-#     pipeline = Pipeline(
-#         model,
-#         tokenizer,
-#         prompter=LlamaChatPrompter(),
-#         formatter=LlamaChatFormatter(),
-#     )
-#     test_example = Example(
-#         instruction="",
-#         input="Paris is in",
-#         output="France",
-#         incorrect_outputs=["Germany", "Italy"],
-#     )
-#     dataset: Dataset = [
-#         test_example,
-#         Example(
-#             instruction="",
-#             input="1 + 1 =",
-#             output="2",
-#             incorrect_outputs=["11", "34", "3.14"],
-#         ),
-#     ]
-
-#     layers = [0, 1, 2]
-#     multiplier = 7
-#     algorithm = RepeReadingControl(
-#         patch_generation_tokens_only=True,
-#         direction_multiplier=multiplier,
-#         layers=layers,
-#     )
-#     algorithm.run(pipeline, dataset)
-#     hook = pipeline.hooks[0]
-
-#     # hackily recreating what the pipeline does during logprobs
-#     base_prompt = pipeline.build_generation_prompt(test_example)
-#     full_prompt = base_prompt + test_example.output
-#     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-#     ctx = PipelineContext(
-#         method="logprobs",
-#         base_prompt=base_prompt,
-#         full_prompt=full_prompt,
-#         inputs=inputs,
-#         pipeline=pipeline,
-#     )
-#     orig_logits = model(**inputs).logits
-#     with hook(ctx):
-#         our_logits = model(**inputs).logits
-
-#     assert isinstance(hook, SteeringHook)  # keep pyright happy
-#     wrapped_model = LlamaWrapper(model, tokenizer, add_only_after_end_str=True)
-#     wrapped_model.reset_all()
-#     for layer in layers:
-#         wrapped_model.set_add_activations(
-#             layer, multiplier * hook.steering_vector.layer_activations[layer]
-#         )
-#     caa_logits = wrapped_model.get_logits(inputs["input_ids"])
-#     # only the final answer tokens should be different
-#     assert torch.allclose(our_logits[0, :-2], orig_logits[0, :-2])
-#     assert torch.allclose(our_logits, caa_logits)
+def test_RepeReadingControl_run_steering_matches_caa_llama_wrapper(
+    empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer
+) -> None:
+    model = empty_llama_model
+    tokenizer = llama_chat_tokenizer
+    pipeline = Pipeline(
+        model,
+        tokenizer,
+        prompter=LlamaChatPrompter(),
+        formatter=LlamaChatFormatter(),
+    )
+    test_example = Example(
+        instruction="",
+        input="Paris is in",
+        output="France",
+        incorrect_outputs=["Germany", "Italy"],
+    )
+    dataset: Dataset = [
+        test_example,
+        Example(
+            instruction="",
+            input="1 + 1 =",
+            output="2",
+            incorrect_outputs=["11", "34", "3.14"],
+        ),
+    ]
+
+    layers = [0, 1, 2]
+    multiplier = 7
+    algorithm = RepeReadingControl(
+        patch_generation_tokens_only=True,
+        # CAA skips the first generation token for some reason, so we do here too to match
+        # https://github.com/nrimsky/CAA/issues/3
+        skip_first_n_generation_tokens=1,
+        direction_multiplier=multiplier,
+        layers=layers,
+    )
+    algorithm.run(pipeline, dataset)
+    hook = pipeline.hooks[0]
+
+    # hackily recreating what the pipeline does during logprobs
+    base_prompt = pipeline.build_generation_prompt(test_example)
+    full_prompt = base_prompt + test_example.output
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+    ctx = PipelineContext(
+        method="logprobs",
+        base_prompt=base_prompt,
+        full_prompt=full_prompt,
+        inputs=inputs,
+        pipeline=pipeline,
+    )
+    orig_logits = model(**inputs).logits
+    with hook(ctx):
+        our_logits = model(**inputs).logits
+
+    assert isinstance(hook, SteeringHook)  # keep pyright happy
+    wrapped_model = LlamaWrapper(model, tokenizer, add_only_after_end_str=True)
+    wrapped_model.reset_all()
+    for layer in layers:
+        wrapped_model.set_add_activations(
+            layer, multiplier * hook.steering_vector.layer_activations[layer]
+        )
+    caa_logits = wrapped_model.get_logits(inputs["input_ids"])
+    # only the final answer tokens should be different
+    assert torch.allclose(our_logits[0, :-2], orig_logits[0, :-2])
+    assert torch.allclose(our_logits, caa_logits)
 
 
 def test_RepeReadingControl_run_logprobs_with_patch_generation_tokens_only(
