Commit 6efbfaf: fixing tests after CAA merge (#85)

chanind authored Jan 29, 2024
1 parent cca4b29 commit 6efbfaf

Showing 4 changed files with 19 additions and 34 deletions.
repepo/algorithms/repe.py: 3 additions & 3 deletions

@@ -116,10 +116,10 @@ def __init__(
         skip_reading: bool = False,
         override_vector: Optional[SteeringVector] = None,
         skip_control: bool = False,
-        # For (A/B) datasets, the second last token corresponds to 'A' or 'B'
-        # which is what the CAA paper extracts.
+        # For (A/B) datasets, the second last token corresponds to 'A' or 'B' which is what the CAA paper extracts.
+        # make sure to set this to -2 when working with CAA's data format
         # Reference: https://github.com/nrimsky/SycophancySteering/blob/25f93a1f1aad51f94288f52d01f6a10d10f42bf1/generate_vectors.py#L102C13-L102C67
-        read_token_index: int = -2,
+        read_token_index: int = -1,
     ):
         self.multi_answer_method = multi_answer_method
         self.layer_type = layer_type
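The index semantics here come from the code comment above: in CAA's multiple-choice format, prompts end in a literal "(A)" or "(B)", and most BPE tokenizers emit the closing parenthesis as its own final token, leaving the answer letter second from the end. The new default of -1 simply reads the last token of an ordinary prompt. A minimal sketch of the tokenization, illustrative only (GPT-2's tokenizer is a stand-in and the exact split is tokenizer-dependent):

# Illustrative sketch, not part of this commit: shows why read_token_index=-2
# lands on the answer letter for CAA-style "(A)"/"(B)" prompts.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer.tokenize("The best answer is: (B)")
print(tokens[-3:])  # e.g. ['Ġ(', 'B', ')'] -> tokens[-2] is the letter 'B'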
repepo/experiments/caa_repro/prompting_with_steering.py: 1 addition & 0 deletions

@@ -113,6 +113,7 @@ def test_steering(
         skip_reading=True,
         direction_multiplier=multiplier,
         override_vector=steering_vector,
+        read_token_index=-2,
     )
     pipeline = Pipeline(model, tokenizer, formatter=LlamaChatFormatter())
     # Run algorithm to create the hooks
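Because the library default is now -1, this CAA-reproduction script opts back in to index -2 explicitly, keeping its steering vectors aligned with the 'A'/'B' token that the original CAA code reads.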
tests/algorithms/__snapshots__/test_repe.ambr: 8 additions & 24 deletions

@@ -3,38 +3,22 @@
   list([
     SteeringVectorTrainingSample(
       negative_prompt='''
-        Consider the amount of truthfulness in the following answer.
-
-        Input: Paris is in
-        Output: Germany
-
-        The amount of truthfulness in the answer is
+        <s> Input: Paris is in
+        Output: Germany
       ''',
       positive_prompt='''
-        Consider the amount of truthfulness in the following answer.
-
-        Input: Paris is in
-        Output: France
-
-        The amount of truthfulness in the answer is
+        <s> Input: Paris is in
+        Output: France
       ''',
     ),
     SteeringVectorTrainingSample(
       negative_prompt='''
-        Consider the amount of truthfulness in the following answer.
-
-        Input: 1 + 1 =
-        Output: 11
-
-        The amount of truthfulness in the answer is
+        <s> Input: 1 + 1 =
+        Output: 11
      ''',
       positive_prompt='''
-        Consider the amount of truthfulness in the following answer.
-
-        Input: 1 + 1 =
-        Output: 2
-
-        The amount of truthfulness in the answer is
+        <s> Input: 1 + 1 =
+        Output: 2
       ''',
     ),
   ])
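The .ambr file is a syrupy snapshot: the test serializes the SteeringVectorTrainingSample prompts it builds and compares them against this stored copy, so the reading-template change surfaces as the prompt rewrite above. After an intentional template change, the snapshot is regenerated rather than hand-edited; with syrupy that is typically `pytest tests/algorithms/test_repe.py --snapshot-update` (assuming the project's standard pytest setup).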
tests/algorithms/test_repe.py: 7 additions & 7 deletions

@@ -1,5 +1,4 @@
 import torch
-import pytest
 from repepo.algorithms.repe import (
     RepeReadingControl,
     SteeringHook,

@@ -13,7 +12,6 @@
 from tests._original_caa.llama_wrapper import LlamaWrapper


-@pytest.mark.skip("Reading template has changed")
 def test_RepeReadingControl_build_steering_vector_training_data_picks_one_neg_by_default(
     model: GPTNeoXForCausalLM,
     tokenizer: Tokenizer,

@@ -143,7 +141,12 @@ def test_RepeReadingControl_get_steering_vector_matches_caa(
         ),
     ]
     layers = [0, 1, 2]
-    algorithm = RepeReadingControl(multi_answer_method="repeat_correct", layers=layers)
+    algorithm = RepeReadingControl(
+        multi_answer_method="repeat_correct",
+        layers=layers,
+        # CAA reads from index -2 for steering vectors
+        read_token_index=-2,
+    )
     steering_vector = algorithm._get_steering_vector(pipeline, dataset)

     steering_training_data = algorithm._build_steering_vector_training_data(

@@ -188,8 +191,7 @@ def test_RepeReadingControl_get_steering_vector_matches_caa(
     ), f"Non-matching activations at layer {layer}"


-@pytest.mark.skip("Reading template has changed")
-def test_RepeReadingControl_run(
+def test_RepeReadingControl_run_basic(
     model: GPTNeoXForCausalLM, tokenizer: Tokenizer
 ) -> None:
     tokenizer.pad_token_id = model.config.eos_token_id

@@ -223,7 +225,6 @@ def test_RepeReadingControl_run(
     assert original_outputs != new_outputs


-@pytest.mark.skip("Reading template has changed")
 def test_RepeReadingControl_run_steering_matches_caa_llama_wrapper(
     empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer
 ) -> None:

@@ -291,7 +292,6 @@ def test_RepeReadingControl_run_steering_matches_caa_llama_wrapper(
     assert torch.allclose(our_logits, caa_logits)


-@pytest.mark.skip("Reading template has changed")
 def test_RepeReadingControl_run_logprobs_with_patch_generation_tokens_only(
     model: GPTNeoXForCausalLM, tokenizer: Tokenizer
 ) -> None:
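For readers skimming the parity tests above: a CAA-style steering vector is the mean difference between positive- and negative-prompt activations read at read_token_index, which is why that index must agree between repepo and the original CAA code. A minimal self-contained sketch of the idea (illustrative; build_steering_vector is a hypothetical helper, not repepo's API):

import torch

def build_steering_vector(
    pos_acts: list[torch.Tensor],  # per-prompt [seq_len, hidden] activations
    neg_acts: list[torch.Tensor],  # paired negative-prompt activations
    read_token_index: int = -2,    # -2 targets the 'A'/'B' token in CAA's format
) -> torch.Tensor:
    # Take the activation difference at the read position for each pair,
    # then average; adding the result into the residual stream steers output.
    diffs = [
        pos[read_token_index] - neg[read_token_index]
        for pos, neg in zip(pos_acts, neg_acts)
    ]
    return torch.stack(diffs).mean(dim=0)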
