From aa1dd24966d9e1d696782b56dafda779ad99bd30 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Mon, 22 Jan 2024 13:05:59 -0500 Subject: [PATCH] Verify our code matches CAA (#76) * adding a llama chat formater and prompter based on CAA * testing that our reading vectors match CAA reading vectors * fixing linting * fixing test --- repepo/core/format.py | 22 +++ repepo/core/prompt.py | 26 ++++ tests/_original_caa/helpers.py | 161 ++++++++++++++++++++ tests/_original_caa/llama_wrapper.py | 215 +++++++++++++++++++++++++++ tests/algorithms/test_repe.py | 136 ++++++++++++++++- tests/conftest.py | 5 + tests/core/test_formatter.py | 19 ++- tests/core/test_prompter.py | 12 +- 8 files changed, 588 insertions(+), 8 deletions(-) create mode 100644 tests/_original_caa/helpers.py create mode 100644 tests/_original_caa/llama_wrapper.py diff --git a/repepo/core/format.py b/repepo/core/format.py index c788bbac..d3dd3bc1 100644 --- a/repepo/core/format.py +++ b/repepo/core/format.py @@ -37,6 +37,28 @@ def apply(self, example: Example, **kwargs): ) +class LlamaChatFormatter(Formatter): + """ + Add [INST] and [/INST] tags to the instruction and input. + + Based on: https://github.com/nrimsky/SycophancySteering/blob/main/utils/tokenize_llama.py#L30 + """ + + B_INST = "[INST]" + E_INST = "[/INST]" + + @override + def apply(self, example: Example): + dialog_content_parts = [] + if example.instruction: + dialog_content_parts.append(example.instruction.strip()) + dialog_content_parts.append(example.input.strip()) + dialog_content = "\n".join(dialog_content_parts) + prompt = f"{self.B_INST} {dialog_content} {self.E_INST} " + response = example.output.strip() + return Completion(prompt=prompt, response=response) + + class InstructionFormatter(Formatter): """Instruction formatter used for fine-tuning Alpaca.""" diff --git a/repepo/core/prompt.py b/repepo/core/prompt.py index 019f1793..5ece58da 100644 --- a/repepo/core/prompt.py +++ b/repepo/core/prompt.py @@ -1,4 +1,5 @@ import abc +from dataclasses import dataclass from typing import List from typing_extensions import override @@ -28,6 +29,31 @@ def apply(self, completion): return completion +@dataclass +class LlamaChatPrompter(Prompter): + """ + Prepend system message before the first prompt. This is based on the SycophancySteering + prompt format, except the system message is not inside of the [INST] tags. Hopefully + this is OK? + + TODO: Put the sys message inside of the [INST] tags, or at least make it possible to do so. + This will require some rearchitecting + + Based on: https://github.com/nrimsky/SycophancySteering/blob/main/utils/tokenize_llama.py#L30 + """ + + system_prompt: str = "You are a helpful, honest and concise assistant." + + B_SYS = "<>\n" + E_SYS = "\n<>\n\n" + + @override + def apply(self, completion: Completion) -> Completion: + prompt = f"{self.B_SYS}{self.system_prompt}{self.E_SYS}{completion.prompt}" + response = completion.response + return Completion(prompt=prompt, response=response) + + class FewShotPrompter(Prompter): """Compose examples few-shot""" diff --git a/tests/_original_caa/helpers.py b/tests/_original_caa/helpers.py new file mode 100644 index 00000000..d6327119 --- /dev/null +++ b/tests/_original_caa/helpers.py @@ -0,0 +1,161 @@ +import torch as t +from typing import Optional +import os +from dataclasses import dataclass + + +@dataclass +class SteeringSettings: + """ + max_new_tokens: Maximum number of tokens to generate. + type: Type of test to run. One of "in_distribution", "out_of_distribution", "truthful_qa". + few_shot: Whether to test with few-shot examples in the prompt. One of "positive", "negative", "none". + do_projection: Whether to project activations onto orthogonal complement of steering vector. + override_vector: If not None, the steering vector generated from this layer's activations will be used at all layers. Use to test the effect of steering with a different layer's vector. + override_vector_model: If not None, steering vectors generated from this model will be used instead of the model being used for generation - use to test vector transference between models. + use_base_model: Whether to use the base model instead of the chat model. + model_size: Size of the model to use. One of "7b", "13b". + n_test_datapoints: Number of datapoints to test on. If None, all datapoints will be used. + add_every_token_position: Whether to add steering vector to every token position (including question), not only to token positions corresponding to the model's response to the user + override_model_weights_path: If not None, the model weights at this path will be used instead of the model being used for generation - use to test using activation steering on top of fine-tuned model. + """ + + max_new_tokens: int = 100 + type: str = "in_distribution" + few_shot: str = "none" + do_projection: bool = False + override_vector: Optional[int] = None + override_vector_model: Optional[str] = None + use_base_model: bool = False + model_size: str = "7b" + n_test_datapoints: Optional[int] = None + add_every_token_position: bool = False + override_model_weights_path: Optional[str] = None + + def make_result_save_suffix( + self, + layer: Optional[int] = None, + multiplier: Optional[int] = None, + ): + elements = { + "layer": layer, + "multiplier": multiplier, + "max_new_tokens": self.max_new_tokens, + "type": self.type, + "few_shot": self.few_shot, + "do_projection": self.do_projection, + "override_vector": self.override_vector, + "override_vector_model": self.override_vector_model, + "use_base_model": self.use_base_model, + "model_size": self.model_size, + "n_test_datapoints": self.n_test_datapoints, + "add_every_token_position": self.add_every_token_position, + "override_model_weights_path": self.override_model_weights_path, + } + return "_".join( + [ + f"{k}={str(v).replace('/', '-')}" + for k, v in elements.items() + if v is not None + ] + ) + + def filter_result_files_by_suffix( + self, + directory: str, + layer: Optional[int] = None, + multiplier: Optional[int] = None, + ): + elements = { + "layer": layer, + "multiplier": multiplier, + "max_new_tokens": self.max_new_tokens, + "type": self.type, + "few_shot": self.few_shot, + "do_projection": self.do_projection, + "override_vector": self.override_vector, + "override_vector_model": self.override_vector_model, + "use_base_model": self.use_base_model, + "model_size": self.model_size, + "n_test_datapoints": self.n_test_datapoints, + "add_every_token_position": self.add_every_token_position, + "override_model_weights_path": self.override_model_weights_path, + } + + filtered_elements = {k: v for k, v in elements.items() if v is not None} + remove_elements = {k for k, v in elements.items() if v is None} + + matching_files = [] + + print(self.override_model_weights_path) + + for filename in os.listdir(directory): + if all( + f"{k}={str(v).replace('/', '-')}" in filename + for k, v in filtered_elements.items() + ): + # ensure remove_elements are *not* present + if all(f"{k}=" not in filename for k in remove_elements): + matching_files.append(filename) + + return [os.path.join(directory, f) for f in matching_files] + + +def project_onto_orthogonal_complement(tensor, onto): + """ + Projects tensor onto the orthogonal complement of the span of onto. + """ + # Get the projection of tensor onto onto + proj = ( + t.sum(tensor * onto, dim=-1, keepdim=True) + * onto + / (t.norm(onto, dim=-1, keepdim=True) ** 2 + 1e-10) + ) + # Subtract to get the orthogonal component + return tensor - proj + + +def add_vector_after_position( + matrix, vector, position_ids, after=None, do_projection=True +): + after_id = after + if after_id is None: + after_id = position_ids.min().item() - 1 + + mask = position_ids > after_id + mask = mask.unsqueeze(-1) + + if do_projection: + matrix = project_onto_orthogonal_complement(matrix, vector) + + matrix += mask.float() * vector + return matrix + + +def find_last_subtensor_position(tensor, sub_tensor): + n, m = tensor.size(0), sub_tensor.size(0) + if m > n: + return -1 + for i in range(n - m, -1, -1): + if t.equal(tensor[i : i + m], sub_tensor): + return i + return -1 + + +def find_instruction_end_postion(tokens, end_str): + start_pos = find_last_subtensor_position(tokens, end_str) + if start_pos == -1: + return -1 + return start_pos + len(end_str) - 1 + + +def get_a_b_probs(logits, a_token_id, b_token_id): + last_token_logits = logits[0, -1, :] + last_token_probs = t.softmax(last_token_logits, dim=-1) + a_prob = last_token_probs[a_token_id].item() + b_prob = last_token_probs[b_token_id].item() + return a_prob, b_prob + + +def make_tensor_save_suffix(layer, model_name_path): + return f'{layer}_{model_name_path.split("/")[-1]}' diff --git a/tests/_original_caa/llama_wrapper.py b/tests/_original_caa/llama_wrapper.py new file mode 100644 index 00000000..c0aff937 --- /dev/null +++ b/tests/_original_caa/llama_wrapper.py @@ -0,0 +1,215 @@ +import torch as t +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from .helpers import add_vector_after_position, find_instruction_end_postion + + +class AttnWrapper(t.nn.Module): + """ + Wrapper for attention mechanism to save activations + """ + + def __init__(self, attn): + super().__init__() + self.attn = attn + self.activations = None + + def forward(self, *args, **kwargs): + output = self.attn(*args, **kwargs) + self.activations = output[0] + return output + + +class BlockOutputWrapper(t.nn.Module): + """ + Wrapper for block to save activations and unembed them + """ + + def __init__(self, block, unembed_matrix, norm, tokenizer): + super().__init__() + self.block = block + self.unembed_matrix = unembed_matrix + self.norm = norm + self.tokenizer = tokenizer + + self.block.self_attn = AttnWrapper(self.block.self_attn) + self.post_attention_layernorm = self.block.post_attention_layernorm + + self.attn_out_unembedded = None + self.intermediate_resid_unembedded = None + self.mlp_out_unembedded = None + self.block_out_unembedded = None + + self.activations = None + self.add_activations = None + self.after_position = None + + self.save_internal_decodings = False + self.do_projection = False + + self.calc_dot_product_with = None + self.dot_products = [] + + def forward(self, *args, **kwargs): + output = self.block(*args, **kwargs) + self.activations = output[0] + if self.calc_dot_product_with is not None: + last_token_activations = self.activations[0, -1, :] + decoded_activations = self.unembed_matrix(self.norm(last_token_activations)) + top_token_id = t.topk(decoded_activations, 1)[1][0] + top_token = self.tokenizer.decode(top_token_id) + dot_product = t.dot(last_token_activations, self.calc_dot_product_with) / ( + t.norm(last_token_activations) * t.norm(self.calc_dot_product_with) + ) + self.dot_products.append((top_token, dot_product.cpu().item())) + if self.add_activations is not None: + augmented_output = add_vector_after_position( + matrix=output[0], + vector=self.add_activations, + position_ids=kwargs["position_ids"], + after=self.after_position, + do_projection=self.do_projection, + ) + output = (augmented_output + self.add_activations,) + output[1:] + + if not self.save_internal_decodings: + return output + + # Whole block unembedded + self.block_output_unembedded = self.unembed_matrix(self.norm(output[0])) + + # Self-attention unembedded + attn_output = self.block.self_attn.activations + self.attn_out_unembedded = self.unembed_matrix(self.norm(attn_output)) + + # Intermediate residual unembedded + attn_output += args[0] + self.intermediate_resid_unembedded = self.unembed_matrix(self.norm(attn_output)) + + # MLP unembedded + mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output)) + self.mlp_out_unembedded = self.unembed_matrix(self.norm(mlp_output)) + + return output + + def add(self, activations, do_projection=False): + self.add_activations = activations + self.do_projection = do_projection + + def reset(self): + self.add_activations = None + self.activations = None + self.block.self_attn.activations = None + self.after_position = None + self.do_projection = False + self.calc_dot_product_with = None + self.dot_products = [] + + +# Modified so we can test with our own model instances without having to loa +# a massive 7b model from the internet in a testcase +class LlamaWrapper: + def __init__( + self, + model: t.nn.Module, + tokenizer: PreTrainedTokenizerBase, + add_only_after_end_str=False, + ): + self.device = "cuda" if t.cuda.is_available() else "cpu" + self.add_only_after_end_str = add_only_after_end_str + self.model = model + self.tokenizer = tokenizer + self.model = self.model.to(self.device) + self.END_STR = t.tensor(self.tokenizer.encode("[/INST]")[1:]).to(self.device) + for i, layer in enumerate(self.model.model.layers): + self.model.model.layers[i] = BlockOutputWrapper( + layer, self.model.lm_head, self.model.model.norm, self.tokenizer + ) + + def set_save_internal_decodings(self, value: bool): + for layer in self.model.model.layers: + layer.save_internal_decodings = value + + def set_after_positions(self, pos: int): + for layer in self.model.model.layers: + layer.after_position = pos + + def generate(self, tokens, max_new_tokens=50): + with t.no_grad(): + if self.add_only_after_end_str: + instr_pos = find_instruction_end_postion(tokens[0], self.END_STR) + else: + instr_pos = None + self.set_after_positions(instr_pos) + generated = self.model.generate( + inputs=tokens, max_new_tokens=max_new_tokens, top_k=1 + ) + return self.tokenizer.batch_decode(generated)[0] + + def get_logits(self, tokens): + with t.no_grad(): + if self.add_only_after_end_str: + instr_pos = find_instruction_end_postion(tokens[0], self.END_STR) + else: + instr_pos = None + self.set_after_positions(instr_pos) + logits = self.model(tokens).logits + return logits + + def get_last_activations(self, layer): + return self.model.model.layers[layer].activations + + def set_add_activations(self, layer, activations, do_projection=False): + self.model.model.layers[layer].add(activations, do_projection) + + def set_calc_dot_product_with(self, layer, vector): + self.model.model.layers[layer].calc_dot_product_with = vector + + def get_dot_products(self, layer): + return self.model.model.layers[layer].dot_products + + def reset_all(self): + for layer in self.model.model.layers: + layer.reset() + + def print_decoded_activations(self, decoded_activations, label, topk=10): + data = self.get_activation_data(decoded_activations, topk)[0] + print(label, data) + + def decode_all_layers( + self, + tokens, + topk=10, + print_attn_mech=True, + print_intermediate_res=True, + print_mlp=True, + print_block=True, + ): + tokens = tokens.to(self.device) + self.get_logits(tokens) + for i, layer in enumerate(self.model.model.layers): + print(f"Layer {i}: Decoded intermediate outputs") + if print_attn_mech: + self.print_decoded_activations( + layer.attn_out_unembedded, "Attention mechanism", topk=topk + ) + if print_intermediate_res: + self.print_decoded_activations( + layer.intermediate_resid_unembedded, + "Intermediate residual stream", + topk=topk, + ) + if print_mlp: + self.print_decoded_activations( + layer.mlp_out_unembedded, "MLP output", topk=topk + ) + if print_block: + self.print_decoded_activations( + layer.block_output_unembedded, "Block output", topk=topk + ) + + def get_activation_data(self, decoded_activations, topk=10): + softmaxed = t.nn.functional.softmax(decoded_activations[0][-1], dim=-1) + values, indices = t.topk(softmaxed, topk) + probs_percent = [int(v * 100) for v in values.tolist()] + tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1)) + return list(zip(tokens, probs_percent)), list(zip(tokens, values.tolist())) diff --git a/tests/algorithms/test_repe.py b/tests/algorithms/test_repe.py index ae44bf01..2aadeb7c 100644 --- a/tests/algorithms/test_repe.py +++ b/tests/algorithms/test_repe.py @@ -1,12 +1,15 @@ +import torch from repepo.algorithms.repe import ( RepeReadingControl, _find_generation_start_token_index, ) -from repepo.core.format import InputOutputFormatter +from repepo.core.format import InputOutputFormatter, LlamaChatFormatter from repepo.core.types import Dataset, Example, Tokenizer from repepo.core.pipeline import Pipeline +from repepo.core.prompt import LlamaChatPrompter from syrupy import SnapshotAssertion -from transformers import GPTNeoXForCausalLM +from transformers import GPTNeoXForCausalLM, LlamaForCausalLM +from tests._original_caa.llama_wrapper import LlamaWrapper def test_RepeReadingControl_build_steering_vector_training_data_picks_one_neg_by_default( @@ -117,6 +120,69 @@ def test_RepeReadingControl_get_steering_vector( assert act.shape == (512,) +def test_RepeReadingControl_get_steering_vector_matches_caa( + empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer +) -> None: + model = empty_llama_model + tokenizer = llama_chat_tokenizer + pipeline = Pipeline( + model, + tokenizer, + prompter=LlamaChatPrompter(), + formatter=LlamaChatFormatter(), + ) + dataset: Dataset = [ + Example( + instruction="", + input="Paris is in", + output="France", + incorrect_outputs=["Germany", "Italy"], + ), + ] + layers = [0, 1, 2] + algorithm = RepeReadingControl(multi_answer_method="repeat_correct", layers=layers) + steering_vector = algorithm._get_steering_vector(pipeline, dataset) + + steering_training_data = algorithm._build_steering_vector_training_data( + dataset, pipeline.formatter + ) + # hackily translated from generate_vectors.py script + tokenized_data = [ + (tokenizer.encode(pos), tokenizer.encode(neg)) + for pos, neg in steering_training_data + ] + pos_activations = dict([(layer, []) for layer in layers]) + neg_activations = dict([(layer, []) for layer in layers]) + wrapped_model = LlamaWrapper(model, tokenizer) + + for p_tokens, n_tokens in tokenized_data: + p_tokens = torch.tensor(p_tokens).unsqueeze(0).to(model.device) + n_tokens = torch.tensor(n_tokens).unsqueeze(0).to(model.device) + wrapped_model.reset_all() + wrapped_model.get_logits(p_tokens) + for layer in layers: + p_activations = wrapped_model.get_last_activations(layer) + p_activations = p_activations[0, -2, :].detach().cpu() + pos_activations[layer].append(p_activations) + wrapped_model.reset_all() + wrapped_model.get_logits(n_tokens) + for layer in layers: + n_activations = wrapped_model.get_last_activations(layer) + n_activations = n_activations[0, -2, :].detach().cpu() + neg_activations[layer].append(n_activations) + + caa_vecs_by_layer = {} + for layer in layers: + all_pos_layer = torch.stack(pos_activations[layer]) + all_neg_layer = torch.stack(neg_activations[layer]) + caa_vecs_by_layer[layer] = (all_pos_layer - all_neg_layer).mean(dim=0) + + for layer in layers: + assert torch.allclose( + steering_vector.layer_activations[layer], caa_vecs_by_layer[layer] + ) + + def test_RepeReadingControl_run( model: GPTNeoXForCausalLM, tokenizer: Tokenizer ) -> None: @@ -151,6 +217,72 @@ def test_RepeReadingControl_run( assert original_outputs != new_outputs +# TODO: uncomment this test when https://github.com/nrimsky/SycophancySteering/issues/2 is fixed +# def test_RepeReadingControl_run_steering_matches_caa_llama_wrapper( +# empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer +# ) -> None: +# model = empty_llama_model +# tokenizer = llama_chat_tokenizer +# pipeline = Pipeline( +# model, +# tokenizer, +# prompter=LlamaChatPrompter(), +# formatter=LlamaChatFormatter(), +# ) +# test_example = Example( +# instruction="", +# input="Paris is in", +# output="France", +# incorrect_outputs=["Germany", "Italy"], +# ) +# dataset: Dataset = [ +# test_example, +# Example( +# instruction="", +# input="1 + 1 =", +# output="2", +# incorrect_outputs=["11", "34", "3.14"], +# ), +# ] + +# layers = [0, 1, 2] +# multiplier = 7 +# algorithm = RepeReadingControl( +# patch_generation_tokens_only=True, +# direction_multiplier=multiplier, +# layers=layers, +# ) +# algorithm.run(pipeline, dataset) +# hook = pipeline.hooks[0] + +# # hackily recreating what the pipeline does during logprobs +# base_prompt = pipeline.build_generation_prompt(test_example) +# full_prompt = base_prompt + test_example.output +# inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device) +# ctx = PipelineContext( +# method="logprobs", +# base_prompt=base_prompt, +# full_prompt=full_prompt, +# inputs=inputs, +# pipeline=pipeline, +# ) +# orig_logits = model(**inputs).logits +# with hook(ctx): +# our_logits = model(**inputs).logits + +# assert isinstance(hook, SteeringHook) # keep pyright happy +# wrapped_model = LlamaWrapper(model, tokenizer, add_only_after_end_str=True) +# wrapped_model.reset_all() +# for layer in layers: +# wrapped_model.set_add_activations( +# layer, multiplier * hook.steering_vector.layer_activations[layer] +# ) +# caa_logits = wrapped_model.get_logits(inputs["input_ids"]) +# # only the final answer tokens should be different +# assert torch.allclose(our_logits[0, :-2], orig_logits[0, :-2]) +# assert torch.allclose(our_logits, caa_logits) + + def test_RepeReadingControl_run_logprobs_with_patch_generation_tokens_only( model: GPTNeoXForCausalLM, tokenizer: Tokenizer ) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index df3110a0..7da379b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,3 +73,8 @@ def empty_llama_model() -> LlamaForCausalLM: intermediate_size=2752, ) return LlamaForCausalLM(config) + + +@pytest.fixture +def llama_chat_tokenizer() -> Tokenizer: + return AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") diff --git a/tests/core/test_formatter.py b/tests/core/test_formatter.py index 345f9037..39fe0b1c 100644 --- a/tests/core/test_formatter.py +++ b/tests/core/test_formatter.py @@ -1,7 +1,9 @@ -from repepo.core.format import InputOutputFormatter -from repepo.core.format import InstructionFormatter -from repepo.core.types import Completion -from repepo.core.types import Example +from repepo.core.format import ( + InputOutputFormatter, + InstructionFormatter, + LlamaChatFormatter, +) +from repepo.core.types import Completion, Example def test_input_output_formatter(): @@ -23,3 +25,12 @@ def test_instruction_formatter(): ) assert completion.prompt == expected_prompt assert completion.response == "3" + + +def test_llama_chat_formatter(): + example = Example(instruction="add numbers", input="1, 2", output="3") + formatter = LlamaChatFormatter() + completion = formatter.apply(example) + expected_prompt = "[INST] add numbers\n1, 2 [/INST] " + assert completion.prompt == expected_prompt + assert completion.response == "3" diff --git a/tests/core/test_prompter.py b/tests/core/test_prompter.py index c142df59..5a874b21 100644 --- a/tests/core/test_prompter.py +++ b/tests/core/test_prompter.py @@ -1,5 +1,4 @@ -from repepo.core.prompt import FewShotPrompter -from repepo.core.prompt import IdentityPrompter +from repepo.core.prompt import FewShotPrompter, IdentityPrompter, LlamaChatPrompter from repepo.core.types import Completion @@ -18,3 +17,12 @@ def test_few_shot_prompter(): assert ( output.prompt == examples[0].prompt + " " + examples[0].response + "\n\nHello" ) + + +def test_llama_chat_prompter(): + completion = Completion(prompt="Hello", response="Hi!") + prompter = LlamaChatPrompter() + output = prompter.apply(completion) + assert output.prompt == ( + "<>\nYou are a helpful, honest and concise assistant.\n<>\n\nHello" + )