Polishing ModelPatcher layer guessing and fully replacing WrappedReadingVecModel (#40)

* polishing layer guessing and fully replacing WrappedReadingVecModel with ModelPatcher

* adding a test for pipeline skipping patching
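
(For illustration only: a minimal sketch of what such a skip-patching test could look like. The tiny model name, the use of unittest.mock, and everything beyond the model_patcher / patch_activations names visible in the diff below are assumptions, not the test actually added in this commit.)

    from unittest.mock import patch

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from repepo.repe.rep_control_pipeline import RepControlPipeline


    def test_pipeline_skips_patching_without_activations():
        # assumed tiny model for a fast test; any causal LM whose layers
        # ModelPatcher can guess would work here
        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
        pipeline = RepControlPipeline(model=model, tokenizer=tokenizer, layers=[0])

        # spy on the patcher: with activations=None, __call__ should never patch the model
        with patch.object(
            pipeline.model_patcher,
            "patch_activations",
            wraps=pipeline.model_patcher.patch_activations,
        ) as spy:
            pipeline("hello world", max_new_tokens=2)
        spy.assert_not_called()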
chanind authored Dec 21, 2023
1 parent 2533e84 commit fb6ccf5
Showing 10 changed files with 398 additions and 736 deletions.
1 change: 0 additions & 1 deletion repepo/repe/__init__.py
@@ -7,7 +7,6 @@
 
 # RepControl
 from .rep_control_pipeline import *
-from .rep_control_reading_vec import *
 
 # RepReading
 from .rep_readers import *
47 changes: 34 additions & 13 deletions repepo/repe/rep_control_pipeline.py
@@ -1,40 +1,61 @@
+from typing import Optional, cast
 import torch
 from transformers.pipelines import TextGenerationPipeline
 
-from .rep_control_reading_vec import WrappedReadingVecModel
+from repepo.utils.model_patcher import (
+    LayerType,
+    ModelLayerConfig,
+    ModelPatcher,
+    PatchOperator,
+    PatchOperatorName,
+)
 
 
 class RepControlPipeline(TextGenerationPipeline):
+    """
+    This is the RepE RepControlPipeline, but with the WrappedReadingVecModel replaced by ModelPatcher
+    NOTE: This is just a temporary fix, and we should rewrite our RepE implementation to avoid any unneeded
+    cruft from the original RepE repo like this class. However, we should do this replacement incrementally
+    so we can ensure we don't accidentally change any behavior compared with the original implementation.
+    """
+
+    block_name: LayerType
+    patch_operator: PatchOperatorName | PatchOperator
+
     def __init__(
         self,
         model,
         tokenizer,
         layers,
-        block_name="decoder_block",
+        block_name: str = "decoder_block",
         control_method="reading_vec",
+        layer_config: Optional[ModelLayerConfig] = None,
+        patch_operator: PatchOperatorName | PatchOperator = "addition",
         **kwargs,
     ):
         # TODO: implement different control method and supported intermediate modules for different models
         assert control_method == "reading_vec", f"{control_method} not supported yet"
-        assert (
-            block_name == "decoder_block"
-            or "LlamaForCausalLM" in model.config.architectures
-        ), f"{model.config.architectures} {block_name} not supported yet"
-        self.wrapped_model = WrappedReadingVecModel(model, tokenizer)
-        self.wrapped_model.unwrap()
-        self.wrapped_model.wrap_block(layers, block_name=block_name)
-        self.block_name = block_name
+        self.model_patcher = ModelPatcher(model, layer_config)
+        self.patch_operator = patch_operator
+        self.block_name = cast(LayerType, block_name)
         self.layers = layers
 
         super().__init__(model=model, tokenizer=tokenizer, **kwargs)
 
     def __call__(self, text_inputs, activations=None, **kwargs):
         if activations is not None:
-            self.wrapped_model.reset()
-            self.wrapped_model.set_controller(self.layers, activations, self.block_name)
+            self.model_patcher.remove_patches()
+            # layers are redundant, just make sure it's not causing confusion
+            assert len(self.layers) == len(activations)
+            for layer in self.layers:
+                assert layer in activations
+            self.model_patcher.patch_activations(
+                activations, self.block_name, self.patch_operator
+            )
 
         with torch.autocast(device_type="cuda"):
             outputs = super().__call__(text_inputs, **kwargs)
-        self.wrapped_model.reset()
+        self.model_patcher.remove_patches()
 
         return outputs
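
(Also for illustration: a minimal usage sketch of the patched pipeline above. The model name, the shape of the steering vectors, and the keying of activations by layer index are assumptions; only the __init__ and __call__ signatures come from this diff.)

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from repepo.repe.rep_control_pipeline import RepControlPipeline

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # assumed placeholder model
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

    layers = [0, 1]
    pipeline = RepControlPipeline(
        model=model,
        tokenizer=tokenizer,
        layers=layers,
        block_name="decoder_block",  # default; cast to a ModelPatcher LayerType internally
        patch_operator="addition",   # default PatchOperatorName
    )

    # __call__ asserts that activations has exactly one entry per configured layer;
    # hidden-size steering vectors are an assumption here
    activations = {
        layer: 0.1 * torch.randn(model.config.hidden_size) for layer in layers
    }
    outputs = pipeline("The weather today is", activations=activations, max_new_tokens=10)
    print(outputs)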