cleaning up oddities with steering vecs and repe algo (#72)
chanind authored Jan 18, 2024
1 parent 1487639 commit 773db50
Showing 3 changed files with 62 additions and 39 deletions.
85 changes: 54 additions & 31 deletions repepo/algorithms/repe.py
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from dataclasses import replace
+from dataclasses import dataclass, replace
 from typing import Literal, Optional
 from typing_extensions import override
 import random
@@ -36,6 +36,45 @@ def _validate_reading_template(reading_template: str):
         )
 
 
+@dataclass
+class SteeringHook:
+    """
+    Pipeline hook that applies a steering vector to the model.
+    All relevant state for the hook is stored in this class.
+    If the params included in this class are changed, it will
+    affect future generation and logprob calls using this hook.
+    """
+
+    steering_vector: SteeringVector
+    direction_multiplier: float
+    patch_generation_tokens_only: bool
+    layer_config: ModelLayerConfig | None
+
+    # PipelineContext is created in both `pipeline.generate` and `pipeline.calculate_output_logprobs`.
+    # It also contains info about the current prompt, which is used to determine which tokens to patch.
+    @contextmanager
+    def __call__(self, context: PipelineContext):
+        handle = None
+        try:
+            min_token_index = 0
+            if self.patch_generation_tokens_only:
+                min_token_index = _find_generation_start_token_index(
+                    context.pipeline.tokenizer,
+                    context.base_prompt,
+                    context.full_prompt,
+                )
+            handle = self.steering_vector.patch_activations(
+                model=context.pipeline.model,
+                layer_config=self.layer_config,
+                multiplier=self.direction_multiplier,
+                min_token_index=min_token_index,
+            )
+            yield
+        finally:
+            if handle is not None:
+                handle.remove()
+
+
 class RepeReadingControl(Algorithm):
     layer_type: LayerType
     multi_answer_method: MultiAnswerMethod
@@ -44,6 +83,7 @@ class RepeReadingControl(Algorithm):
     direction_multiplier: float
     layer_config: ModelLayerConfig | None
     patch_generation_tokens_only: bool
+    read_token_index: int
     seed: int
 
     def __init__(
@@ -59,6 +99,10 @@ def __init__(
         skip_reading: bool = False,
         override_vector: Optional[SteeringVector] = None,
         skip_control: bool = False,
+        # For (A/B) datasets, the second last token corresponds to 'A' or 'B'
+        # which is what the CAA paper extracts.
+        # Reference: https://github.com/nrimsky/SycophancySteering/blob/25f93a1f1aad51f94288f52d01f6a10d10f42bf1/generate_vectors.py#L102C13-L102C67
+        read_token_index: int = -2,
     ):
         self.multi_answer_method = multi_answer_method
         self.layer_type = layer_type
@@ -67,6 +111,7 @@ def __init__(
         _validate_reading_template(reading_template)
         self.reading_template = reading_template
         self.layers = layers
+        self.read_token_index = read_token_index
         self.layer_config = layer_config
         self.direction_multiplier = direction_multiplier
 
@@ -110,6 +155,8 @@ def _get_steering_vector(
             layers=self.layers,
             layer_type=self.layer_type,
             layer_config=self.layer_config,
+            move_to_cpu=True,
+            read_token_index=self.read_token_index,
         )
 
     @override
@@ -133,38 +180,14 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
         # whenever we are in a `PipelineContext`'s scope.
         # After exiting the context, the hook is deleted.
 
-        # The PipelineContext is created in both `pipeline.generate` or `pipeline.calculate_output_logprobs`
-
         # need to use a hook so we can inspect the current thing being generated to know
         # which tokens to patch
-        @contextmanager
-        def steering_hook(context: PipelineContext):
-            handle = None
-            try:
-                min_token_index = 0
-                if self.patch_generation_tokens_only:
-                    min_token_index = _find_generation_start_token_index(
-                        pipeline.tokenizer,
-                        context.base_prompt,
-                        context.full_prompt,
-                    )
-                handle = steering_vector.patch_activations(
-                    model=pipeline.model,
-                    layer_config=self.layer_config,
-                    # NOTE: if the direction multiplier is changed,
-                    # subsequent generations will use the new value
-                    # because this is a reference to the outer scope.
-                    # This is probably counterintuitive
-                    # NOTE: Same goes for layer_config above,
-                    # but this is less critical because layer config is likely static
-                    # TODO: change at some point.
-                    multiplier=self.direction_multiplier,
-                    min_token_index=min_token_index,
-                )
-                yield
-            finally:
-                if handle is not None:
-                    handle.remove()
+        steering_hook = SteeringHook(
+            steering_vector=steering_vector,
+            direction_multiplier=self.direction_multiplier,
+            patch_generation_tokens_only=self.patch_generation_tokens_only,
+            layer_config=self.layer_config,
+        )
 
         if not self.skip_control:
             pipeline.hooks.append(steering_hook)
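
The practical effect of this refactor: the hook's parameters were previously captured from the enclosing scope of `run`, so reassigning `self.direction_multiplier` changed later generations invisibly (the NOTE comments in the removed closure flag exactly this). As a dataclass, that state is explicit, and changing it is a deliberate operation. A minimal sketch of the new usage, where `pipeline`, `steering_vector`, and `example` are placeholder objects and `generate`'s exact signature may differ:

# Sketch only: exercising the hook's now-explicit state.
steering_hook = SteeringHook(
    steering_vector=steering_vector,
    direction_multiplier=1.0,
    patch_generation_tokens_only=True,
    layer_config=None,  # None defers to the guessed layer config
)
pipeline.hooks.append(steering_hook)

steered = pipeline.generate(example)       # patched at multiplier 1.0
steering_hook.direction_multiplier = -1.0  # explicit state change...
negated = pipeline.generate(example)       # ...visibly affects this call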
3 changes: 3 additions & 0 deletions repepo/core/pipeline.py
@@ -36,6 +36,7 @@ class PipelineContext:
     base_prompt: str
     full_prompt: str
     inputs: Any
+    pipeline: "Pipeline"
 
 
 PipelineHook = Callable[[PipelineContext], AbstractContextManager]
@@ -72,6 +73,7 @@ def generate(
             base_prompt=base_prompt,
             full_prompt=base_prompt,
             inputs=inputs,
+            pipeline=self,
         )
         with ExitStack() as stack:
             for hook in self.hooks:
@@ -97,6 +99,7 @@ def calculate_output_logprobs(self, example: Example) -> TextProbs:
             base_prompt=base_prompt,
             full_prompt=full_prompt,
             inputs=inputs,
+            pipeline=self,
         )
         with ExitStack() as stack:
             for hook in self.hooks:
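Threading `pipeline` through `PipelineContext` is what lets `SteeringHook.__call__` reach the model and tokenizer without closing over an outer-scope pipeline, and any other hook gets the same access. A hypothetical minimal hook as a sketch (`token_count_hook` is illustrative, and this assumes the pipeline's tokenizer is a Hugging Face tokenizer):

from contextlib import contextmanager

# Hypothetical hook: everything it needs arrives via the context,
# including the owning Pipeline, so no outer-scope capture is required.
@contextmanager
def token_count_hook(context: PipelineContext):
    tokens = context.pipeline.tokenizer.tokenize(context.full_prompt)
    print(f"hook sees {len(tokens)} prompt tokens")
    yield  # nothing to clean up; patching hooks remove their handles here

pipeline.hooks.append(token_count_hook)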
13 changes: 5 additions & 8 deletions steering_vectors/train_steering_vector.py
@@ -23,6 +23,7 @@ def train_steering_vector(
     layer_type: LayerType = "decoder_block",
     layer_config: Optional[ModelLayerConfig] = None,
     move_to_cpu: bool = False,
+    read_token_index: int = -1,
     # TODO: add more options to control training
 ) -> SteeringVector:
     layer_config = guess_and_enhance_layer_config(model, layer_config, layer_type)
@@ -37,6 +38,7 @@ def train_steering_vector(
         layer_type=layer_type,
         layer_config=layer_config,
         layers=layers,
+        read_token_index=read_token_index,
     )
     neg_acts = _extract_activations(
         model,
@@ -45,6 +47,7 @@ def train_steering_vector(
         layer_type=layer_type,
         layer_config=layer_config,
         layers=layers,
+        read_token_index=read_token_index,
     )
     for layer_num, pos_act in pos_acts.items():
         if move_to_cpu:
@@ -71,6 +74,7 @@ def _extract_activations(
     layer_type: LayerType,
     layer_config: ModelLayerConfig,
     layers: list[int] | None,
+    read_token_index: int,
 ) -> dict[int, Tensor]:
     input = tokenizer(prompt, return_tensors="pt").to(model.device)
     results = {}
@@ -79,12 +83,5 @@ def _extract_activations(
     ) as record:
         model(**input)
     for layer_num, activation in record.items():
-        # NOTE: This is hardcoded to extract the second-last token activtion only
-        # For (A/B) datasets, the second last token corresponds to 'A' or 'B'
-        # which is what the CAA paper extracts.
-        # Reference: https://github.com/nrimsky/SycophancySteering/blob/25f93a1f1aad51f94288f52d01f6a10d10f42bf1/generate_vectors.py#L102C13-L102C67
-
-        # TODO: allow controlling which token(s) to extract
-
-        results[layer_num] = activation[-1][0, -2].detach()
+        results[layer_num] = activation[-1][0, read_token_index].detach()
     return results
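
On the defaults: the library-level `train_steering_vector` keeps a generic `read_token_index=-1` (last token), while `RepeReadingControl` passes `-2` because A/B-formatted prompts end with the choice in parentheses, so the final token is the closing `)` and the answer letter sits at index `-2`. A quick illustrative check, using GPT-2's tokenizer purely as a stand-in (exact token splits vary by tokenizer):

from transformers import AutoTokenizer

# Illustrative only: where the answer letter lands in an A/B-style prompt.
tok = AutoTokenizer.from_pretrained("gpt2")
tokens = tok.tokenize("Question: Is the sky blue?\nAnswer: (A)")
print(tokens[-3:])  # e.g. ['Ġ(', 'A', ')'] -- the letter is at index -2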
