Skip to content

Commit

Permalink
silencing repe litning errors (#21)
Browse files Browse the repository at this point in the history
* silencing repe litning errors

* adding HF token to CI
  • Loading branch information
chanind authored Nov 29, 2023
1 parent 97bcbaf commit a49c9b6
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ on: [push]
jobs:
lint_test_and_build:
runs-on: ubuntu-latest
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
Expand Down
52 changes: 29 additions & 23 deletions repepo/algorithms/repe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# from repepo.core import Dataset
from typing import Any
from repepo.core import Pipeline
from repepo.repe.rep_reading_pipeline import RepReadingPipeline
from repepo.repe.rep_control_pipeline import RepControlPipeline
Expand All @@ -10,11 +11,11 @@

import torch


class Repe(BaseAlgorithm):
# TODO: linting

def __init__(self):

self.rep_token = -1
self.n_difference = 1
self.direction_method = "pca"
Expand All @@ -34,60 +35,63 @@ def run(self, pipeline: Pipeline, dataset) -> Pipeline:
# TODO: make parameter
layer_ids = [idx for idx in hidden_layers if idx % 3 == 0]



tokenizer.pad_token_id = (
0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
)
tokenizer.bos_token_id = 1

rep_reading_pipeline = RepReadingPipeline(model=model, tokenizer=tokenizer)
train_data, test_data = dataset["train"], dataset["test"]
train_data, test_data = dataset["train"], dataset["test"] # type: ignore
rep_reader = rep_reading_pipeline.get_directions(
train_data["data"],
rep_token=rep_token,
rep_token=self.rep_token,
hidden_layers=layer_ids,
n_difference=n_difference,
n_difference=self.n_difference,
train_labels=train_data["labels"],
direction_method=direction_method,
direction_method=self.direction_method,
)

rep_control_pipeline = RepControlPipeline(
model=model,
tokenizer=tokenizer,
layers=layer_ids,
block_name=block_name,
control_method=control_method,
block_name=self.block_name,
control_method=self.control_method,
)
# breakpoint()
activations = {}
# TODO: potential erros here
for layer in layer_ids:
activations[layer] = (
torch.tensor(
coeff * rep_reader.directions[layer] * rep_reader.direction_signs[layer]
).to(model.device).half()
self.coeff
* rep_reader.directions[layer]
* rep_reader.direction_signs[layer]
)
.to(model.device)
.half()
)


from functools import partial
control_outputs = partial(rep_control_pipeline,

control_outputs = partial(
rep_control_pipeline,
activations=activations,
batch_size=4,
max_new_tokens=max_new_tokens,
max_new_tokens=self.max_new_tokens,
do_sample=False,
)
)


breakpoint()
# TODO: how to format the new model so that the structure is preserved

return pipeline

if __name__ == '__main__':

if __name__ == "__main__":
from repepo.repe.repe_dataset import bias_dataset
dataset = bias_dataset()

# TODO: fix typing
dataset: Any = bias_dataset()

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
Expand All @@ -112,12 +116,14 @@ def run(self, pipeline: Pipeline, dataset) -> Pipeline:
token=True,
cache_dir=cache_dir,
)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
tokenizer.pad_token_id = (
0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
)
tokenizer.bos_token_id = 1
pipeline = Pipeline(
model=model,
tokenizer=tokenizer,
prompter = IdentityPrompter(),
formatter=InstructionFormatter()
prompter=IdentityPrompter(),
formatter=InstructionFormatter(),
)
new_pipeline = Repe().run(pipeline, dataset)
new_pipeline = Repe().run(pipeline, dataset)
4 changes: 2 additions & 2 deletions repepo/repe/rep_reading_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_directions(
self,
train_inputs: Union[str, List[str], List[List[str]]],
rep_token: Union[str, int] = -1,
hidden_layers: Union[str, int] = -1,
hidden_layers: Union[str, int, Sequence[Union[str, int]]] = -1,
n_difference: int = 1,
batch_size: int = 8,
train_labels: List[int] = None,
Expand Down
15 changes: 11 additions & 4 deletions tests/repe/test_run_repe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import torch
import math


def test_run_repe(model: GPTNeoXForCausalLM, tokenizer: Tokenizer) -> None:
assert model.config.name_or_path == "EleutherAI/pythia-70m"


def test_rep_readers_and_control(model: GPTNeoXForCausalLM, tokenizer: Tokenizer) -> None:
def test_rep_readers_and_control(
model: GPTNeoXForCausalLM, tokenizer: Tokenizer
) -> None:
"""
Test that the rep readers work for Pythia 70m with double precision
"""
Expand Down Expand Up @@ -46,7 +49,6 @@ def test_rep_readers_and_control(model: GPTNeoXForCausalLM, tokenizer: Tokenizer
assert rep_reader.directions is not None
assert math.isclose(rep_reader.directions[-3][0][0], 0.00074, abs_tol=1e-5)


rep_control_pipeline = RepControlPipeline(
model=model,
tokenizer=tokenizer,
Expand All @@ -68,7 +70,9 @@ def test_rep_readers_and_control(model: GPTNeoXForCausalLM, tokenizer: Tokenizer
)

assert activations[-3].shape == torch.Size([1, 512])
assert math.isclose(float(activations[-3][0][0]), 7.410049147438258e-05, abs_tol=1e-6)
assert math.isclose(
float(activations[-3][0][0]), 7.410049147438258e-05, abs_tol=1e-6
)

inputs = "12345"
control_outputs = rep_control_pipeline(
Expand All @@ -79,4 +83,7 @@ def test_rep_readers_and_control(model: GPTNeoXForCausalLM, tokenizer: Tokenizer
do_sample=False,
)

assert control_outputs[0]["generated_text"] == '123456789_1\n\n#define S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S'
assert (
control_outputs[0]["generated_text"]
== "123456789_1\n\n#define S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S_S"
)

0 comments on commit a49c9b6

Please sign in to comment.