Polishing ModelPatcher layer guessing and fully replacing WrappedReadingVecModel (#40)

* polishing layer guessing and fully replacing WrappedReadingVecModel with ModelPatcher
* adding a test for pipeline skipping patching
Showing 10 changed files with 398 additions and 736 deletions.
```diff
@@ -1,40 +1,61 @@
+from typing import Optional, cast
 import torch
 from transformers.pipelines import TextGenerationPipeline
 
-from .rep_control_reading_vec import WrappedReadingVecModel
+from repepo.utils.model_patcher import (
+    LayerType,
+    ModelLayerConfig,
+    ModelPatcher,
+    PatchOperator,
+    PatchOperatorName,
+)
 
 
 class RepControlPipeline(TextGenerationPipeline):
+    """
+    This is the RepE RepControlPipeline, but with the WrappedReadingVecModel replaced by ModelPatcher
+    NOTE: This is just a temporary fix, and we should rewrite our RepE implementation to avoid any unneeded
+    cruft from the original RepE repo like this class. However, we should do this replacement incrementally
+    so we can ensure we don't accidentally change any behavior compared with the original implementation.
+    """
+
+    block_name: LayerType
+    patch_operator: PatchOperatorName | PatchOperator
+
     def __init__(
         self,
         model,
         tokenizer,
         layers,
-        block_name="decoder_block",
+        block_name: str = "decoder_block",
         control_method="reading_vec",
+        layer_config: Optional[ModelLayerConfig] = None,
+        patch_operator: PatchOperatorName | PatchOperator = "addition",
         **kwargs,
     ):
         # TODO: implement different control method and supported intermediate modules for different models
         assert control_method == "reading_vec", f"{control_method} not supported yet"
         assert (
             block_name == "decoder_block"
             or "LlamaForCausalLM" in model.config.architectures
         ), f"{model.config.architectures} {block_name} not supported yet"
-        self.wrapped_model = WrappedReadingVecModel(model, tokenizer)
-        self.wrapped_model.unwrap()
-        self.wrapped_model.wrap_block(layers, block_name=block_name)
-        self.block_name = block_name
+        self.model_patcher = ModelPatcher(model, layer_config)
+        self.patch_operator = patch_operator
+        self.block_name = cast(LayerType, block_name)
         self.layers = layers
 
         super().__init__(model=model, tokenizer=tokenizer, **kwargs)
 
     def __call__(self, text_inputs, activations=None, **kwargs):
         if activations is not None:
-            self.wrapped_model.reset()
-            self.wrapped_model.set_controller(self.layers, activations, self.block_name)
+            self.model_patcher.remove_patches()
+            # layers are redundant, just make sure it's not causing confusion
+            assert len(self.layers) == len(activations)
+            for layer in self.layers:
+                assert layer in activations
+            self.model_patcher.patch_activations(
+                activations, self.block_name, self.patch_operator
+            )
 
         with torch.autocast(device_type="cuda"):
             outputs = super().__call__(text_inputs, **kwargs)
-        self.wrapped_model.reset()
+        self.model_patcher.remove_patches()
 
         return outputs
```
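
For context, here is a minimal sketch of how the rewritten pipeline might be driven after this change. It is an illustration, not code from the commit: the checkpoint path, layer indices, and random steering vectors are hypothetical placeholders, and `RepControlPipeline`'s import path is not visible in this excerpt. Only the class, its constructor arguments, and the "activations keyed by layer" contract come from the diff above.

```python
# Minimal usage sketch, assuming a CUDA GPU (the pipeline hard-codes
# torch.autocast(device_type="cuda")) and a Llama-style checkpoint.
# "path/to/llama-checkpoint", the layer indices, and the random steering
# vectors below are hypothetical, not values from the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-checkpoint")

layers = [10, 12, 14]  # decoder blocks to patch

# RepControlPipeline is the class defined in the diff above; its import
# path is not shown in this excerpt.
pipe = RepControlPipeline(
    model=model,
    tokenizer=tokenizer,
    layers=layers,
    block_name="decoder_block",  # default LayerType
    patch_operator="addition",   # default PatchOperatorName
    device=0,                    # extra kwargs are forwarded to TextGenerationPipeline
)

# One steering vector per patched layer, keyed by layer index
# (__call__ asserts that every entry in `layers` has a matching key).
activations = {
    layer: torch.randn(model.config.hidden_size, device="cuda") for layer in layers
}

# With `activations` set, __call__ clears any stale patches, applies new ones via
# ModelPatcher.patch_activations, runs generation, and removes the patches again.
outputs = pipe("The weather today is", activations=activations, max_new_tokens=20)
print(outputs)
```

Note the design choice visible in `__call__`: patches are applied per call and removed again afterwards (and also before applying new ones), so a single pipeline instance can be reused with different steering vectors without leftover patches accumulating on the model.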