CAA #66
* Add CAA datasets
* Update makefile
* Add test for make_ab_prompt

Co-authored-by: dtch1997 <dtch1997@users.noreply.github.com>
layer_config=self.layer_config,
# NOTE: if the direction multiplier is changed,
# subsequent generations will use the new value
# because this is a reference to the outer scope.
# This is probably counterintuitive
# NOTE: Same goes for layer_config above,
# but this is less critical because layer config is likely static
# TODO: change at some point.
multiplier=self.direction_multiplier,
This behaviour is highly unintuitive, as the hooks are stored in pipeline but they still read the state from the RepeReadingControl algorithm after .run terminates.
We should refactor this before merging.
Generally, we should try to ensure all relevant state that the hooks will reference, is encapsulated within the The focus should be on making it easy to modify:
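
To make the concern in this comment concrete, here is a minimal sketch of the difference between a hook that reads its multiplier from the algorithm object each time it runs and a hook that copies the value when it is created. The `Algorithm` class and both `make_*_hook` methods are hypothetical stand-ins written for illustration; they are not the code in this PR.

```python
from typing import Callable, List


class Algorithm:
    """Hypothetical stand-in for an algorithm object that owns a multiplier."""

    def __init__(self, direction_multiplier: float) -> None:
        self.direction_multiplier = direction_multiplier

    def make_late_bound_hook(self) -> Callable[[float], float]:
        # The closure looks up self.direction_multiplier on every call,
        # so changes made after the hook is created leak into the hook.
        return lambda activation: activation + self.direction_multiplier

    def make_snapshot_hook(self) -> Callable[[float], float]:
        # The value is copied once; the hook carries its own state.
        multiplier = self.direction_multiplier
        return lambda activation: activation + multiplier


algo = Algorithm(direction_multiplier=1.0)
hooks: List[Callable[[float], float]] = [
    algo.make_late_bound_hook(),
    algo.make_snapshot_hook(),
]

before = [hook(10.0) for hook in hooks]
algo.direction_multiplier = 5.0  # mutate the algorithm after the hooks exist
after = [hook(10.0) for hook in hooks]

print(before)
print(after)
```

Only the late-bound hook changes its output after the attribute is mutated, which is the surprising behaviour the comment describes. Snapshotting is one possible fix; passing the state explicitly into the hook would be another.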
repepo/algorithms/repe.py
# Steering vector reading
# NOTE: The hooks read from this steering vector.
steering_vector = self._get_steering_vector(pipeline, dataset)

# Creating the hooks that will do steering vector control
# NOTE: How this works is that we create a context manager that creates a hook
# whenever we are in a `PipelineContext`'s scope.
# After exiting the context, the hook is deleted.

# The PipelineContext is created in both `pipeline.generate` or `pipeline.calculate_output_logprobs`
@chanind could you comment on whether I've described the logic here accurately?
It's not correct that the hook is deleted after exiting the context, but it could be a confusion between the Pipeline hook and the PyTorch hook. The pipeline hook is just in an array on the pipeline, and stays there until it's removed. The hook only gets applied to the model during pipeline.generate or pipeline.calculate_output_logprobs.
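
The distinction drawn in this reply can be sketched with a toy pipeline. Everything below (`ToyModel`, `ToyPipeline`, `steering_hook`) is a hypothetical illustration under the assumption that a pipeline hook is a callable returning a context manager; the real `Pipeline` class may differ.

```python
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, Iterator, List


class ToyModel:
    """Hypothetical model; tracks which patches are currently attached."""

    def __init__(self) -> None:
        self.active_patches: List[str] = []


class ToyPipeline:
    """Hypothetical pipeline that keeps its hooks in a plain list."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model
        self.hooks: List[Callable[[ToyModel], ContextManager[None]]] = []

    def generate(self, prompt: str) -> str:
        # Hooks are entered only for the duration of this call.
        with ExitStack() as stack:
            for hook in self.hooks:
                stack.enter_context(hook(self.model))
            return f"{prompt} [patches active: {len(self.model.active_patches)}]"


@contextmanager
def steering_hook(model: ToyModel) -> Iterator[None]:
    # Attach the model-level patch on entry and detach it on exit.
    model.active_patches.append("steering")
    try:
        yield
    finally:
        model.active_patches.remove("steering")


pipeline = ToyPipeline(ToyModel())
pipeline.hooks.append(steering_hook)
output = pipeline.generate("hi")
print(output)
print(len(pipeline.hooks), pipeline.model.active_patches)
```

After `generate` returns, the model-level patch is gone but the pipeline hook is still in `pipeline.hooks`, so the next call applies it again.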
This PR removes all REPE code and replaces it with CAA-style steering vectors, where each steering vector is found by simply subtracting pos - neg for every pair and then taking the mean over pairs.
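
As a rough illustration of the "subtract pos - neg, then take the mean" step, the sketch below works on plain Python lists. The function name `mean_difference_vector` and the toy inputs are assumptions made for this example; it does not extract activations from a model and is not the API of the `steering_vectors` module.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def mean_difference_vector(pairs: Sequence[Tuple[Vector, Vector]]) -> Vector:
    """Average of (pos - neg) over all pairs, computed element-wise."""
    if not pairs:
        raise ValueError("at least one (pos, neg) pair is required")
    dim = len(pairs[0][0])
    total = [0.0] * dim
    for pos, neg in pairs:
        if len(pos) != dim or len(neg) != dim:
            raise ValueError("all vectors must have the same length")
        for i in range(dim):
            total[i] += pos[i] - neg[i]
    return [value / len(pairs) for value in total]


# Toy stand-ins for activations recorded on positive and negative prompts.
pairs = [
    ([1.0, 2.0], [0.0, 0.0]),
    ([3.0, 2.0], [1.0, 4.0]),
]
vec = mean_difference_vector(pairs)
print(vec)
```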
This PR is large because it removes the old Repe code and also moves some of the existing code into a steering_vectors module. This PR introduces the following ideas:
Steering Vectors
The steering_vectors module is separated out from the rest of the code, since it can be published as its own library. Its public API consists of 2 main components: train_steering_vector() and SteeringVector. The train_steering_vector() function takes a list of paired pos and neg prompts, and returns a steering vector instance. The steering vector can be used to steer generation in an LLM.
Basic usage:
There are a number of improvements we can make to this in the future, such as:
That being said, it's probably already publishable as a standalone Python library.
Pipeline hooks
Since CAA requires that we only patch activations after the prompt, we need a way to tell the steering vector which token in the given prompt should be patched. Our current implementation of Pipeline doesn't have a way to feed this information into the steering vector, so to get around this, this PR adds the concept of a hook to the Pipeline class. These hooks take in a context object, which contains info about what the pipeline is doing (which example is being parsed, what the base prompt text is, what the full prompt text is, etc.), and then wrap the generation/logprobs calculation. This way, repe has enough information about what the pipeline is currently running to correctly patch activations.
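
One way a hook could use such a context to patch only the positions after the prompt is sketched below. The `HookContext` fields, the whitespace tokenization, and `patch_after_prompt` are all assumptions made for illustration; a real implementation would use the model's tokenizer and patch tensors inside a forward hook.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class HookContext:
    """Hypothetical context object passed to a pipeline hook."""

    base_prompt: str  # the prompt without the completion
    full_prompt: str  # the prompt followed by the completion


def patch_after_prompt(
    activations: List[List[float]],
    steering: List[float],
    context: HookContext,
    multiplier: float = 1.0,
) -> List[List[float]]:
    """Add multiplier * steering to every position after the base prompt."""
    # Assumption: whitespace splitting stands in for real tokenization.
    start = len(context.base_prompt.split())
    patched = []
    for position, activation in enumerate(activations):
        if position >= start:
            patched.append(
                [a + multiplier * s for a, s in zip(activation, steering)]
            )
        else:
            patched.append(list(activation))
    return patched


context = HookContext(base_prompt="Q: hi A:", full_prompt="Q: hi A: hello there")
num_tokens = len(context.full_prompt.split())
activations = [[0.0, 0.0] for _ in range(num_tokens)]
patched = patch_after_prompt(activations, [1.0, -1.0], context, multiplier=2.0)
print(patched)
```

Here the first three positions (the base prompt) are left unchanged and only the two completion positions receive the steering vector.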